mirror of https://github.com/tracel-ai/burn.git
Refactor index => slice (#466)
This commit is contained in:
parent
042d2201d2
commit
65bf6c1cbb
|
@ -44,11 +44,11 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
|
|||
B::bool_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
fn bool_slice<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<B, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
) -> BoolTensor<B, D1> {
|
||||
B::bool_index(tensor, indexes)
|
||||
B::bool_slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn bool_empty<const D: usize>(
|
||||
|
@ -58,12 +58,12 @@ impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B>
|
|||
B::bool_empty(shape, device)
|
||||
}
|
||||
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
value: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
|
||||
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1> {
|
||||
B::bool_index_assign(tensor, indexes, value)
|
||||
fn bool_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<Self, D1>,
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
value: BoolTensor<Self, D1>,
|
||||
) -> BoolTensor<Self, D1> {
|
||||
B::bool_slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D> {
|
||||
|
|
|
@ -43,11 +43,11 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::int_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_index<const D1: usize, const D2: usize>(
|
||||
fn int_slice<const D1: usize, const D2: usize>(
|
||||
tensor: IntTensor<B, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
) -> IntTensor<B, D1> {
|
||||
B::int_index(tensor, indexes)
|
||||
B::int_slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn int_empty<const D: usize>(
|
||||
|
@ -57,12 +57,12 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::int_empty(shape, device)
|
||||
}
|
||||
|
||||
fn int_index_assign<const D1: usize, const D2: usize>(
|
||||
fn int_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: IntTensor<B, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
value: IntTensor<B, D1>,
|
||||
) -> IntTensor<B, D1> {
|
||||
B::int_index_assign(tensor, indexes, value)
|
||||
B::int_slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn int_cat<const D: usize>(tensors: Vec<IntTensor<B, D>>, dim: usize) -> IntTensor<B, D> {
|
||||
|
@ -198,35 +198,35 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
fn int_gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: IntTensor<B, D>,
|
||||
indexes: IntTensor<B, D>,
|
||||
indices: IntTensor<B, D>,
|
||||
) -> IntTensor<B, D> {
|
||||
B::int_gather(dim, tensor, indexes)
|
||||
B::int_gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: IntTensor<B, D>,
|
||||
indexes: IntTensor<B, D>,
|
||||
indices: IntTensor<B, D>,
|
||||
value: IntTensor<B, D>,
|
||||
) -> IntTensor<B, D> {
|
||||
B::int_scatter(dim, tensor, indexes, value)
|
||||
B::int_scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_index_select_dim<const D: usize>(
|
||||
fn int_select<const D: usize>(
|
||||
tensor: IntTensor<B, D>,
|
||||
dim: usize,
|
||||
indexes: IntTensor<B, 1>,
|
||||
indices: IntTensor<B, 1>,
|
||||
) -> IntTensor<B, D> {
|
||||
B::int_index_select_dim(tensor, dim, indexes)
|
||||
B::int_select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
tensor: IntTensor<B, D1>,
|
||||
fn int_select_assign<const D: usize>(
|
||||
tensor: IntTensor<B, D>,
|
||||
dim: usize,
|
||||
indexes: IntTensor<B, 1>,
|
||||
value: IntTensor<B, D2>,
|
||||
) -> IntTensor<B, D1> {
|
||||
B::int_index_select_dim_assign(tensor, dim, indexes, value)
|
||||
indices: IntTensor<B, 1>,
|
||||
value: IntTensor<B, D>,
|
||||
) -> IntTensor<B, D> {
|
||||
B::int_select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn int_mask_where<const D: usize>(
|
||||
|
@ -260,11 +260,11 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
) -> B::IntTensorPrimitive<D> {
|
||||
B::int_max_dim(tensor, dim)
|
||||
}
|
||||
fn int_max_dim_with_indexes<const D: usize>(
|
||||
fn int_max_dim_with_indices<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
|
||||
B::int_max_dim_with_indexes(tensor, dim)
|
||||
B::int_max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
fn int_min<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
|
||||
B::int_min(tensor)
|
||||
|
@ -275,10 +275,10 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
) -> B::IntTensorPrimitive<D> {
|
||||
B::int_min_dim(tensor, dim)
|
||||
}
|
||||
fn int_min_dim_with_indexes<const D: usize>(
|
||||
fn int_min_dim_with_indices<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
|
||||
B::int_min_dim_with_indexes(tensor, dim)
|
||||
B::int_min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,11 +10,11 @@ impl<B: Backend, const D: usize> Backward<B, D, 1> for MaxMinDim {
|
|||
|
||||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let (indexes, shape) = ops.state;
|
||||
let (indices, shape) = ops.state;
|
||||
let device = B::device(&grad);
|
||||
let zeros = B::zeros(shape, &device);
|
||||
|
||||
B::scatter(D - 1, zeros, indexes, grad)
|
||||
B::scatter(D - 1, zeros, indices, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ use burn_tensor::ops::*;
|
|||
use super::OpsKind;
|
||||
|
||||
impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||
fn embedding(weights: ADTensor<B, 2>, indexes: IntTensor<B, 2>) -> ADTensor<B, 3> {
|
||||
fn embedding(weights: ADTensor<B, 2>, indices: IntTensor<B, 2>) -> ADTensor<B, 3> {
|
||||
#[derive(Debug)]
|
||||
struct Embedding;
|
||||
|
||||
|
@ -17,10 +17,10 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
type State = (B::TensorPrimitive<2>, IntTensor<B, 2>);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
let (weights, indexes) = ops.state;
|
||||
let (weights, indices) = ops.state;
|
||||
|
||||
unary::<B, 3, 2, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::embedding_backward(weights, grad, indexes)
|
||||
B::embedding_backward(weights, grad, indices)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -30,19 +30,19 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
.statefull()
|
||||
{
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(weights.primitive.clone(), indexes.clone()),
|
||||
B::embedding(weights.primitive, indexes),
|
||||
(weights.primitive.clone(), indices.clone()),
|
||||
B::embedding(weights.primitive, indices),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indexes)),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::embedding(weights.primitive, indices)),
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_backward(
|
||||
weights: ADTensor<B, 2>,
|
||||
output: ADTensor<B, 3>,
|
||||
indexes: IntTensor<B, 2>,
|
||||
indices: IntTensor<B, 2>,
|
||||
) -> ADTensor<B, 2> {
|
||||
let tensor = B::embedding_backward(weights.primitive, output.primitive, indexes);
|
||||
let tensor = B::embedding_backward(weights.primitive, output.primitive, indices);
|
||||
ADTensor::new(tensor)
|
||||
}
|
||||
|
||||
|
@ -360,9 +360,9 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let output =
|
||||
B::max_pool2d_with_indexes(x.primitive.clone(), kernel_size, stride, padding);
|
||||
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
|
||||
prep.finish(
|
||||
(x.primitive, output.indexes, kernel_size, stride, padding),
|
||||
(x.primitive, output.indices, kernel_size, stride, padding),
|
||||
output.output,
|
||||
)
|
||||
}
|
||||
|
@ -372,21 +372,21 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
}
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes(
|
||||
fn max_pool2d_with_indices(
|
||||
x: ADTensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
) -> MaxPool2dWithIndexes<ADBackendDecorator<B>> {
|
||||
) -> MaxPool2dWithIndices<ADBackendDecorator<B>> {
|
||||
match MaxPool2D.prepare([x.node], [x.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let output =
|
||||
B::max_pool2d_with_indexes(x.primitive.clone(), kernel_size, stride, padding);
|
||||
B::max_pool2d_with_indices(x.primitive.clone(), kernel_size, stride, padding);
|
||||
|
||||
let output_tensor = prep.finish(
|
||||
(
|
||||
x.primitive,
|
||||
output.indexes.clone(),
|
||||
output.indices.clone(),
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
|
@ -394,32 +394,32 @@ impl<B: Backend> ModuleOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
output.output,
|
||||
);
|
||||
|
||||
MaxPool2dWithIndexes::new(output_tensor, output.indexes)
|
||||
MaxPool2dWithIndices::new(output_tensor, output.indices)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => {
|
||||
let output = B::max_pool2d_with_indexes(x.primitive, kernel_size, stride, padding);
|
||||
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||
let output_tensor = prep.finish(output.output);
|
||||
|
||||
MaxPool2dWithIndexes::new(output_tensor, output.indexes)
|
||||
MaxPool2dWithIndices::new(output_tensor, output.indices)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes_backward(
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: ADTensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
output_grad: ADTensor<B, 4>,
|
||||
indexes: IntTensor<B, 4>,
|
||||
indices: IntTensor<B, 4>,
|
||||
) -> MaxPool2dBackward<ADBackendDecorator<B>> {
|
||||
let output = B::max_pool2d_with_indexes_backward(
|
||||
let output = B::max_pool2d_with_indices_backward(
|
||||
x.primitive,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
output_grad.primitive,
|
||||
indexes,
|
||||
indices,
|
||||
);
|
||||
MaxPool2dBackward::new(ADTensor::new(output.x_grad))
|
||||
}
|
||||
|
@ -440,11 +440,11 @@ impl<B: Backend> Backward<B, 4, 1> for MaxPool2D {
|
|||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
let [node_parent] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
let (x, indexes, kernel_size, stride, padding) = ops.state;
|
||||
let (x, indices, kernel_size, stride, padding) = ops.state;
|
||||
|
||||
if let Some(node) = node_parent {
|
||||
let grad =
|
||||
B::max_pool2d_with_indexes_backward(x, kernel_size, stride, padding, grad, indexes);
|
||||
B::max_pool2d_with_indices_backward(x, kernel_size, stride, padding, grad, indices);
|
||||
|
||||
grads.register::<B, 4>(node, grad.x_grad);
|
||||
}
|
||||
|
|
|
@ -441,7 +441,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: ADTensor<B, D>,
|
||||
indexes: IntTensor<B, D>,
|
||||
indices: IntTensor<B, D>,
|
||||
) -> ADTensor<B, D> {
|
||||
#[derive(Debug)]
|
||||
struct Gather;
|
||||
|
@ -450,11 +450,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
type State = (usize, IntTensor<B, D>, Shape<D>, B::Device);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
let (dim, indexes, shape, device) = ops.state;
|
||||
let (dim, indices, shape, device) = ops.state;
|
||||
|
||||
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let zeros = B::zeros(shape, &device);
|
||||
B::scatter(dim, zeros, indexes, grad)
|
||||
B::scatter(dim, zeros, indices, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -463,20 +463,20 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(
|
||||
dim,
|
||||
indexes.clone(),
|
||||
indices.clone(),
|
||||
B::shape(&tensor.primitive),
|
||||
B::device(&tensor.primitive),
|
||||
),
|
||||
B::gather(dim, tensor.primitive, indexes),
|
||||
B::gather(dim, tensor.primitive, indices),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indexes)),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indices)),
|
||||
}
|
||||
}
|
||||
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: ADTensor<B, D>,
|
||||
indexes: IntTensor<B, D>,
|
||||
indices: IntTensor<B, D>,
|
||||
value: ADTensor<B, D>,
|
||||
) -> ADTensor<B, D> {
|
||||
#[derive(Debug)]
|
||||
|
@ -486,8 +486,8 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
type State = (usize, IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||
let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state;
|
||||
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
|
||||
let (dim, indices, shape_lhs, shape_rhs, device) = ops.state;
|
||||
let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices));
|
||||
|
||||
binary::<B, D, D, D, _, _>(
|
||||
ops.parents,
|
||||
|
@ -495,11 +495,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
grads,
|
||||
|grad| {
|
||||
let zeros = B::zeros(shape_lhs, &device);
|
||||
B::scatter(dim, grad, indexes_4lhs.unwrap(), zeros)
|
||||
B::scatter(dim, grad, indices_4lhs.unwrap(), zeros)
|
||||
},
|
||||
|grad| {
|
||||
let zeros = B::zeros(shape_rhs, &device);
|
||||
B::scatter(dim, zeros, indexes_4rhs.unwrap(), grad)
|
||||
B::scatter(dim, zeros, indices_4rhs.unwrap(), grad)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -512,23 +512,23 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(
|
||||
dim,
|
||||
indexes.clone(),
|
||||
indices.clone(),
|
||||
B::shape(&tensor.primitive),
|
||||
B::shape(&value.primitive),
|
||||
B::device(&value.primitive),
|
||||
),
|
||||
B::scatter(dim, tensor.primitive, indexes, value.primitive),
|
||||
B::scatter(dim, tensor.primitive, indices, value.primitive),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::scatter(dim, tensor.primitive, indexes, value.primitive))
|
||||
prep.finish(B::scatter(dim, tensor.primitive, indices, value.primitive))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn index_select<const D: usize>(
|
||||
fn select<const D: usize>(
|
||||
tensor: ADTensor<B, D>,
|
||||
dim: usize,
|
||||
indexes: IntTensor<B, 1>,
|
||||
indices: IntTensor<B, 1>,
|
||||
) -> ADTensor<B, D> {
|
||||
#[derive(Debug)]
|
||||
struct IndexSelectDim;
|
||||
|
@ -537,11 +537,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
type State = (usize, IntTensor<B, 1>, Shape<D>, B::Device);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
let (dim, indexes, shape, device) = ops.state;
|
||||
let (dim, indices, shape, device) = ops.state;
|
||||
|
||||
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let zeros = B::zeros(shape, &device);
|
||||
B::index_select_assign(zeros, dim, indexes, grad)
|
||||
B::select_assign(zeros, dim, indices, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -553,76 +553,74 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(
|
||||
dim,
|
||||
indexes.clone(),
|
||||
indices.clone(),
|
||||
B::shape(&tensor.primitive),
|
||||
B::device(&tensor.primitive),
|
||||
),
|
||||
B::index_select(tensor.primitive, dim, indexes),
|
||||
B::select(tensor.primitive, dim, indices),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::index_select(tensor.primitive, dim, indexes))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::select(tensor.primitive, dim, indices)),
|
||||
}
|
||||
}
|
||||
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: ADTensor<B, D1>,
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: ADTensor<B, D>,
|
||||
dim: usize,
|
||||
indexes: IntTensor<B, 1>,
|
||||
value: ADTensor<B, D2>,
|
||||
) -> ADTensor<B, D1> {
|
||||
indices: IntTensor<B, 1>,
|
||||
value: ADTensor<B, D>,
|
||||
) -> ADTensor<B, D> {
|
||||
#[derive(Debug)]
|
||||
struct IndexSelectDimAssign<const D2: usize>;
|
||||
struct IndexSelectDimAssign<const D: usize>;
|
||||
|
||||
impl<B: Backend, const D1: usize, const D2: usize> Backward<B, D1, 2> for IndexSelectDimAssign<D2> {
|
||||
type State = (usize, IntTensor<B, 1>, Shape<D1>, Shape<D2>, B::Device);
|
||||
impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectDimAssign<D> {
|
||||
type State = (usize, IntTensor<B, 1>, Shape<D>, Shape<D>, B::Device);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||
let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state;
|
||||
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
|
||||
let (dim, indices, shape_lhs, shape_rhs, device) = ops.state;
|
||||
let [indices_4lhs, indices_4rhs] = duplicate(&ops.parents, Some(indices));
|
||||
|
||||
binary::<B, D1, D1, D2, _, _>(
|
||||
binary::<B, D, D, D, _, _>(
|
||||
ops.parents,
|
||||
ops.node,
|
||||
grads,
|
||||
|grad| {
|
||||
let zeros = B::zeros(shape_lhs, &device);
|
||||
B::index_select_assign(grad, dim, indexes_4lhs.unwrap(), zeros)
|
||||
B::select_assign(grad, dim, indices_4lhs.unwrap(), zeros)
|
||||
},
|
||||
|grad| {
|
||||
let zeros = B::zeros(shape_rhs, &device);
|
||||
B::index_select_assign(zeros, dim, indexes_4rhs.unwrap(), grad)
|
||||
B::select_assign(zeros, dim, indices_4rhs.unwrap(), grad)
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match IndexSelectDimAssign::<D2>
|
||||
match IndexSelectDimAssign::<D>
|
||||
.prepare([tensor.node, value.node], [tensor.graph, value.graph])
|
||||
.statefull()
|
||||
{
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(
|
||||
dim,
|
||||
indexes.clone(),
|
||||
indices.clone(),
|
||||
B::shape(&tensor.primitive),
|
||||
B::shape(&value.primitive),
|
||||
B::device(&value.primitive),
|
||||
),
|
||||
B::index_select_assign(tensor.primitive, dim, indexes, value.primitive),
|
||||
B::select_assign(tensor.primitive, dim, indices, value.primitive),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::select_assign(
|
||||
tensor.primitive,
|
||||
dim,
|
||||
indexes,
|
||||
indices,
|
||||
value.primitive,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: ADTensor<B, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
) -> ADTensor<B, D1> {
|
||||
#[derive(Debug)]
|
||||
struct Index<const D2: usize>;
|
||||
|
@ -631,11 +629,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
type State = ([std::ops::Range<usize>; D2], Shape<D1>, B::Device);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||
let (indexes, shape, device) = ops.state;
|
||||
let (ranges, shape, device) = ops.state;
|
||||
|
||||
unary::<B, D1, D1, _>(ops.parents, ops.node, grads, |grad| {
|
||||
let zeros = B::zeros(shape, &device);
|
||||
B::index_assign(zeros, indexes, grad)
|
||||
B::slice_assign(zeros, ranges, grad)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -643,19 +641,19 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
match Index.prepare([tensor.node], [tensor.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(
|
||||
indexes.clone(),
|
||||
ranges.clone(),
|
||||
B::shape(&tensor.primitive),
|
||||
B::device(&tensor.primitive),
|
||||
),
|
||||
B::index(tensor.primitive, indexes),
|
||||
B::slice(tensor.primitive, ranges),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::index(tensor.primitive, indexes)),
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::slice(tensor.primitive, ranges)),
|
||||
}
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: ADTensor<B, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
value: ADTensor<B, D1>,
|
||||
) -> ADTensor<B, D1> {
|
||||
#[derive(Debug)]
|
||||
|
@ -665,8 +663,8 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
type State = ([std::ops::Range<usize>; D2], Shape<D1>, B::Device);
|
||||
|
||||
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||
let (indexes, shape_rhs, device) = ops.state;
|
||||
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
|
||||
let (ranges, shape_rhs, device) = ops.state;
|
||||
let [ranges_4lhs, ranges_4rhs] = duplicate(&ops.parents, Some(ranges));
|
||||
|
||||
binary::<B, D1, D1, D1, _, _>(
|
||||
ops.parents,
|
||||
|
@ -674,9 +672,9 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
grads,
|
||||
|grad| {
|
||||
let zeros = B::zeros(shape_rhs, &device);
|
||||
B::index_assign(grad, indexes_4lhs.unwrap(), zeros)
|
||||
B::slice_assign(grad, ranges_4lhs.unwrap(), zeros)
|
||||
},
|
||||
|grad| B::index(grad, indexes_4rhs.unwrap()),
|
||||
|grad| B::slice(grad, ranges_4rhs.unwrap()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -687,14 +685,14 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
{
|
||||
OpsKind::Tracked(prep) => prep.finish(
|
||||
(
|
||||
indexes.clone(),
|
||||
ranges.clone(),
|
||||
B::shape(&value.primitive),
|
||||
B::device(&value.primitive),
|
||||
),
|
||||
B::index_assign(tensor.primitive, indexes, value.primitive),
|
||||
B::slice_assign(tensor.primitive, ranges, value.primitive),
|
||||
),
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::index_assign(tensor.primitive, indexes, value.primitive))
|
||||
prep.finish(B::slice_assign(tensor.primitive, ranges, value.primitive))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1263,8 +1261,8 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
impl<B: Backend, const D: usize> Step for CatStep<B, D> {
|
||||
fn step(self: Box<Self>, grads: &mut Gradients) {
|
||||
let grad = grads.consume::<B, D>(&self.output);
|
||||
let indexes: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
|
||||
let indexes: [std::ops::Range<usize>; D] = indexes.try_into().unwrap();
|
||||
let ranges: Vec<_> = B::shape(&grad).dims.iter().map(|v| 0..*v).collect();
|
||||
let ranges: [std::ops::Range<usize>; D] = ranges.try_into().unwrap();
|
||||
|
||||
let mut current_index = 0;
|
||||
|
||||
|
@ -1273,10 +1271,10 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
.zip(self.dim_sizes.into_iter())
|
||||
.filter_map(|(node, dim_size)| node.map(|node| (node, dim_size)))
|
||||
.for_each(|(node, dim_size)| {
|
||||
let mut indexes = indexes.clone();
|
||||
indexes[self.dim] = current_index..dim_size + current_index;
|
||||
let mut ranges = ranges.clone();
|
||||
ranges[self.dim] = current_index..dim_size + current_index;
|
||||
current_index += dim_size;
|
||||
grads.register::<B, D>(node, B::index(grad.clone(), indexes));
|
||||
grads.register::<B, D>(node, B::slice(grad.clone(), ranges));
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1318,26 +1316,26 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let shape = B::shape(&tensor.primitive);
|
||||
let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
|
||||
let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim);
|
||||
prep.finish((index, shape), tensor)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::max_dim(tensor.primitive, dim)),
|
||||
}
|
||||
}
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
fn max_dim_with_indices<const D: usize>(
|
||||
tensor: ADTensor<B, D>,
|
||||
dim: usize,
|
||||
) -> (ADTensor<B, D>, IntTensor<B, D>) {
|
||||
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let shape = B::shape(&tensor.primitive);
|
||||
let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
|
||||
let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim);
|
||||
let tensor = prep.finish((index.clone(), shape), tensor);
|
||||
|
||||
(tensor, index)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => {
|
||||
let (tensor, index) = B::max_dim_with_indexes(tensor.primitive, dim);
|
||||
let (tensor, index) = B::max_dim_with_indices(tensor.primitive, dim);
|
||||
let tensor = prep.finish(tensor);
|
||||
|
||||
(tensor, index)
|
||||
|
@ -1348,26 +1346,26 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let shape = B::shape(&tensor.primitive);
|
||||
let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
|
||||
let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim);
|
||||
prep.finish((index, shape), tensor)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::min_dim(tensor.primitive, dim)),
|
||||
}
|
||||
}
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
fn min_dim_with_indices<const D: usize>(
|
||||
tensor: ADTensor<B, D>,
|
||||
dim: usize,
|
||||
) -> (ADTensor<B, D>, IntTensor<B, D>) {
|
||||
match MaxMinDim.prepare([tensor.node], [tensor.graph]).statefull() {
|
||||
OpsKind::Tracked(prep) => {
|
||||
let shape = B::shape(&tensor.primitive);
|
||||
let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
|
||||
let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim);
|
||||
let tensor = prep.finish((index.clone(), shape), tensor);
|
||||
|
||||
(tensor, index)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => {
|
||||
let (tensor, index) = B::min_dim_with_indexes(tensor.primitive, dim);
|
||||
let (tensor, index) = B::min_dim_with_indices(tensor.primitive, dim);
|
||||
let tensor = prep.finish(tensor);
|
||||
|
||||
(tensor, index)
|
||||
|
|
|
@ -6,16 +6,16 @@ mod tests {
|
|||
#[test]
|
||||
fn test_embedding_backward() {
|
||||
let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = Data::from([[0, 1], [1, 1]]);
|
||||
let indices = Data::from([[0, 1], [1, 1]]);
|
||||
let x = Data::from([
|
||||
[[1.0, 2.0], [4.0, 5.0], [3.0, 4.0]],
|
||||
[[4.0, 5.0], [8.0, 5.0], [1.0, 9.0]],
|
||||
]);
|
||||
let weights = Tensor::<TestADBackend, 2>::from_data(weights).require_grad();
|
||||
let indexes = Tensor::<TestADBackend, 2, Int>::from_data(indexes);
|
||||
let indices = Tensor::<TestADBackend, 2, Int>::from_data(indices);
|
||||
let x = Tensor::<TestADBackend, 3>::from_data(x).require_grad();
|
||||
|
||||
let output = embedding(weights.clone(), indexes);
|
||||
let output = embedding(weights.clone(), indices);
|
||||
let output = output.matmul(x);
|
||||
let grads = output.backward();
|
||||
|
||||
|
|
|
@ -18,8 +18,8 @@ mod tests {
|
|||
let mut tensor_2_list = Vec::new();
|
||||
|
||||
for i in 0..2 {
|
||||
tensor_1_list.push(tensor_1.clone().index([i..i + 1]));
|
||||
tensor_2_list.push(tensor_2.clone().index([i..i + 1]));
|
||||
tensor_1_list.push(tensor_1.clone().slice([i..i + 1]));
|
||||
tensor_2_list.push(tensor_2.clone().slice([i..i + 1]));
|
||||
}
|
||||
|
||||
let tensor_1_cat = TestADTensor::cat(tensor_1_list.clone(), 0);
|
||||
|
@ -28,31 +28,31 @@ mod tests {
|
|||
let tensor_3_cat = tensor_1_cat.clone().matmul(tensor_2_cat.clone());
|
||||
let grads = tensor_3_cat.backward();
|
||||
|
||||
let grad_1_index_1 = tensor_1.grad(&grads).unwrap().index([0..1]);
|
||||
let grad_1_index_2 = tensor_1.grad(&grads).unwrap().index([1..2]);
|
||||
let grad_1_slice_1 = tensor_1.grad(&grads).unwrap().slice([0..1]);
|
||||
let grad_1_slice_2 = tensor_1.grad(&grads).unwrap().slice([1..2]);
|
||||
|
||||
let grad_2_index_1 = tensor_2.grad(&grads).unwrap().index([0..1]);
|
||||
let grad_2_index_2 = tensor_2.grad(&grads).unwrap().index([1..2]);
|
||||
let grad_2_slice_1 = tensor_2.grad(&grads).unwrap().slice([0..1]);
|
||||
let grad_2_slice_2 = tensor_2.grad(&grads).unwrap().slice([1..2]);
|
||||
|
||||
grad_1
|
||||
.clone()
|
||||
.index([0..1])
|
||||
.slice([0..1])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_1_index_1.to_data(), 3);
|
||||
.assert_approx_eq(&grad_1_slice_1.to_data(), 3);
|
||||
grad_1
|
||||
.index([1..2])
|
||||
.slice([1..2])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_1_index_2.to_data(), 3);
|
||||
.assert_approx_eq(&grad_1_slice_2.to_data(), 3);
|
||||
|
||||
grad_2
|
||||
.clone()
|
||||
.index([0..1])
|
||||
.slice([0..1])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_2_index_1.to_data(), 3);
|
||||
.assert_approx_eq(&grad_2_slice_1.to_data(), 3);
|
||||
grad_2
|
||||
.index([1..2])
|
||||
.slice([1..2])
|
||||
.to_data()
|
||||
.assert_approx_eq(&grad_2_index_2.to_data(), 3);
|
||||
.assert_approx_eq(&grad_2_slice_2.to_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -7,10 +7,10 @@ mod tests {
|
|||
fn test_gather_grad() {
|
||||
let tensor_1 =
|
||||
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
|
||||
let indices = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 = tensor_1.clone().gather(1, indexes);
|
||||
let tensor_3 = tensor_1.clone().gather(1, indices);
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
@ -29,10 +29,10 @@ mod tests {
|
|||
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||
let values =
|
||||
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
|
||||
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
|
||||
let indices = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 = tensor_1.clone().scatter(1, indexes, values.clone());
|
||||
let tensor_3 = tensor_1.clone().scatter(1, indices, values.clone());
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
|
|
@ -17,8 +17,6 @@ mod erf;
|
|||
mod exp;
|
||||
mod gather_scatter;
|
||||
mod gelu;
|
||||
mod index;
|
||||
mod index_select;
|
||||
mod log;
|
||||
mod log1p;
|
||||
mod mask;
|
||||
|
@ -31,7 +29,9 @@ mod neg;
|
|||
mod pow;
|
||||
mod relu;
|
||||
mod reshape;
|
||||
mod select;
|
||||
mod sin;
|
||||
mod slice;
|
||||
mod softmax;
|
||||
mod sqrt;
|
||||
mod sub;
|
||||
|
@ -71,9 +71,9 @@ macro_rules! testgen_all {
|
|||
burn_autodiff::testgen_ad_div!();
|
||||
burn_autodiff::testgen_ad_erf!();
|
||||
burn_autodiff::testgen_ad_exp!();
|
||||
burn_autodiff::testgen_ad_index!();
|
||||
burn_autodiff::testgen_ad_slice!();
|
||||
burn_autodiff::testgen_ad_gather_scatter!();
|
||||
burn_autodiff::testgen_ad_index_select!();
|
||||
burn_autodiff::testgen_ad_select!();
|
||||
burn_autodiff::testgen_ad_log!();
|
||||
burn_autodiff::testgen_ad_log1p!();
|
||||
burn_autodiff::testgen_ad_mask!();
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
#[burn_tensor_testgen::testgen(ad_index_select)]
|
||||
#[burn_tensor_testgen::testgen(ad_select)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn test_index_select_grad() {
|
||||
fn test_select_grad() {
|
||||
let tensor_1 =
|
||||
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||
let indexes = TestADTensor::from_data(Data::from([1, 0]));
|
||||
let indices = TestADTensor::from_data(Data::from([1, 0]));
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 = tensor_1.clone().index_select(0, indexes);
|
||||
let tensor_3 = tensor_1.clone().select(0, indices);
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
||||
|
@ -24,17 +24,15 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_index_select_assign_grad() {
|
||||
fn test_select_assign_grad() {
|
||||
let tensor_1 =
|
||||
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||
let values =
|
||||
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
|
||||
let indexes = TestADTensor::from_data(Data::from([1, 0]));
|
||||
let indices = TestADTensor::from_data(Data::from([1, 0]));
|
||||
|
||||
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||
let tensor_3 = tensor_1
|
||||
.clone()
|
||||
.index_select_assign(0, indexes, values.clone());
|
||||
let tensor_3 = tensor_1.clone().select_assign(0, indices, values.clone());
|
||||
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||
|
||||
let grads = tensor_4.backward();
|
|
@ -1,17 +1,17 @@
|
|||
#[burn_tensor_testgen::testgen(ad_index)]
|
||||
#[burn_tensor_testgen::testgen(ad_slice)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_index() {
|
||||
fn should_diff_matmul_with_slice() {
|
||||
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0, 100.0], [2.0, 3.0, 15.0]]);
|
||||
|
||||
let tensor_1 = TestADTensor::from_data(data_1).require_grad();
|
||||
let tensor_2 = TestADTensor::from_data(data_2).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().index([0..2, 0..2]);
|
||||
let tensor_3 = tensor_2.clone().slice([0..2, 0..2]);
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_3);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
|
@ -26,7 +26,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_index_assign() {
|
||||
fn should_diff_matmul_with_slice_assign() {
|
||||
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_assigned: Data<f32, 2> = Data::from([[9.0]]);
|
||||
|
@ -36,7 +36,7 @@ mod tests {
|
|||
let tensor_assigned = TestADTensor::from_data(data_assigned).require_grad();
|
||||
|
||||
let tensor_3 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_4 = tensor_3.index_assign([0..1, 0..1], tensor_assigned);
|
||||
let tensor_4 = tensor_3.slice_assign([0..1, 0..1], tensor_assigned);
|
||||
let tensor_5 = tensor_4.matmul(tensor_1.clone());
|
||||
|
||||
let grads = tensor_5.backward();
|
||||
|
@ -49,7 +49,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn should_diff_matmul_with_index_assign_complex() {
|
||||
fn should_diff_matmul_with_slice_assign_complex() {
|
||||
let data_1: Data<f32, 2> = Data::from([[1.0, 7.0], [2.0, 3.0]]);
|
||||
let data_2: Data<f32, 2> = Data::from([[4.0, 7.0], [2.0, 3.0]]);
|
||||
let data_3: Data<f32, 2> = Data::from([[9.0]]);
|
||||
|
@ -59,9 +59,9 @@ mod tests {
|
|||
let tensor_3 = TestADTensor::from_data(data_3).require_grad();
|
||||
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_2.clone());
|
||||
let tensor_5 = tensor_2.clone().index([0..1, 0..1]);
|
||||
let tensor_5 = tensor_2.clone().slice([0..1, 0..1]);
|
||||
let tensor_6 = tensor_5.mul(tensor_3.clone());
|
||||
let tensor_7 = tensor_4.index_assign([0..1, 0..1], tensor_6);
|
||||
let tensor_7 = tensor_4.slice_assign([0..1, 0..1], tensor_6);
|
||||
let tensor_8 = tensor_7.matmul(tensor_1.clone());
|
||||
|
||||
let grads = tensor_8.backward();
|
|
@ -14,7 +14,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
|
|||
|
||||
for i in 0..(seq_length - 1) {
|
||||
let values = Tensor::<B, 3, Int>::ones([1, 1, seq_length - (i + 1)]);
|
||||
mask = mask.index_assign([0..1, i..i + 1, i + 1..seq_length], values);
|
||||
mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
mask = mask.to_device(device).repeat(0, batch_size);
|
||||
|
@ -68,7 +68,7 @@ pub fn generate_padding_mask<B: Backend>(
|
|||
}
|
||||
}
|
||||
|
||||
tensor = tensor.index_assign(
|
||||
tensor = tensor.slice_assign(
|
||||
[index..index + 1, 0..tokens.len()],
|
||||
Tensor::from_data(Data::new(
|
||||
tokens.into_iter().map(|e| (e as i64).elem()).collect(),
|
||||
|
|
|
@ -366,7 +366,7 @@ mod tests {
|
|||
|
||||
// Create a padding mask
|
||||
let mask_pad: Tensor<TestBackend, 2, Int> = Tensor::zeros([batch_size, seq_length]);
|
||||
let mask_pad = mask_pad.index_assign(
|
||||
let mask_pad = mask_pad.slice_assign(
|
||||
[0..batch_size, seq_length - num_padded..seq_length],
|
||||
Tensor::ones([batch_size, num_padded]),
|
||||
);
|
||||
|
@ -377,7 +377,7 @@ mod tests {
|
|||
Distribution::Standard,
|
||||
);
|
||||
// Change the end of the tensor
|
||||
let tensor_2 = tensor_1.clone().index_assign(
|
||||
let tensor_2 = tensor_1.clone().slice_assign(
|
||||
[
|
||||
0..batch_size,
|
||||
seq_length - num_padded..seq_length,
|
||||
|
@ -395,12 +395,12 @@ mod tests {
|
|||
// Check that the begginning of each tensor is the same
|
||||
output_1
|
||||
.context
|
||||
.index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
|
||||
.slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
|
||||
.into_data()
|
||||
.assert_approx_eq(
|
||||
&output_2
|
||||
.context
|
||||
.index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
|
||||
.slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
|
||||
.into_data(),
|
||||
3,
|
||||
);
|
||||
|
@ -423,9 +423,9 @@ mod tests {
|
|||
let mut cache = MhaCache::autoregressive();
|
||||
|
||||
for i in 1..seq_length + 1 {
|
||||
let tensor = tensor.clone().index([0..batch_size, 0..i, 0..d_model]);
|
||||
let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
|
||||
let input = MhaInput::self_attn(tensor);
|
||||
let next_tok = mha.forward_cache(input, &mut cache).context.index([
|
||||
let next_tok = mha.forward_cache(input, &mut cache).context.slice([
|
||||
0..batch_size,
|
||||
i - 1..i,
|
||||
0..d_model,
|
||||
|
|
|
@ -21,7 +21,7 @@ impl<B: Backend, const D: usize> TensorCache<B, D> {
|
|||
CacheState::Value(tensor_old) => {
|
||||
let [batch_size, seq_length, d_model] = tensor.dims();
|
||||
let next_seq_token =
|
||||
tensor.index([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
|
||||
tensor.slice([0..batch_size, (seq_length - 1)..seq_length, 0..d_model]);
|
||||
let next_seq_token = func(next_seq_token);
|
||||
|
||||
Tensor::cat(vec![tensor_old, next_seq_token], dim_cat)
|
||||
|
|
|
@ -120,14 +120,8 @@ impl<B: Backend> Gru<B> {
|
|||
|
||||
for t in 0..seq_length {
|
||||
let indices = Tensor::arange(t..t + 1);
|
||||
let input_t = batched_input
|
||||
.clone()
|
||||
.index_select(1, indices.clone())
|
||||
.squeeze(1);
|
||||
let hidden_t = hidden_state
|
||||
.clone()
|
||||
.index_select(1, indices.clone())
|
||||
.squeeze(1);
|
||||
let input_t = batched_input.clone().select(1, indices.clone()).squeeze(1);
|
||||
let hidden_t = hidden_state.clone().select(1, indices.clone()).squeeze(1);
|
||||
|
||||
// u(pdate)g(ate) tensors
|
||||
let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
|
||||
|
@ -149,7 +143,7 @@ impl<B: Backend> Gru<B> {
|
|||
.mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
|
||||
+ update_values.clone().mul(hidden_t);
|
||||
|
||||
hidden_state = hidden_state.index_assign(
|
||||
hidden_state = hidden_state.slice_assign(
|
||||
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
state_vector.clone().unsqueeze(),
|
||||
);
|
||||
|
@ -268,7 +262,7 @@ mod tests {
|
|||
|
||||
let state = gru.forward(input, None);
|
||||
|
||||
let output = state.index_select(0, Tensor::arange(0..1)).squeeze(0);
|
||||
let output = state.select(0, Tensor::arange(0..1)).squeeze(0);
|
||||
|
||||
output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3);
|
||||
}
|
||||
|
|
|
@ -141,7 +141,7 @@ impl<B: Backend> Lstm<B> {
|
|||
|
||||
for t in 0..seq_length {
|
||||
let indices = Tensor::arange(t..t + 1);
|
||||
let input_t = batched_input.clone().index_select(1, indices).squeeze(1);
|
||||
let input_t = batched_input.clone().select(1, indices).squeeze(1);
|
||||
// f(orget)g(ate) tensors
|
||||
let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate);
|
||||
let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state
|
||||
|
@ -162,11 +162,11 @@ impl<B: Backend> Lstm<B> {
|
|||
hidden_state = output_values * cell_state.clone().tanh();
|
||||
|
||||
// store the state for this timestep
|
||||
batched_cell_state = batched_cell_state.index_assign(
|
||||
batched_cell_state = batched_cell_state.slice_assign(
|
||||
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
cell_state.clone().unsqueeze(),
|
||||
);
|
||||
batched_hidden_state = batched_hidden_state.index_assign(
|
||||
batched_hidden_state = batched_hidden_state.slice_assign(
|
||||
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
hidden_state.clone().unsqueeze(),
|
||||
);
|
||||
|
@ -312,11 +312,9 @@ mod tests {
|
|||
let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]));
|
||||
|
||||
let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None);
|
||||
let cell_state = cell_state_batch
|
||||
.index_select(0, Tensor::arange(0..1))
|
||||
.squeeze(0);
|
||||
let cell_state = cell_state_batch.select(0, Tensor::arange(0..1)).squeeze(0);
|
||||
let hidden_state = hidden_state_batch
|
||||
.index_select(0, Tensor::arange(0..1))
|
||||
.select(0, Tensor::arange(0..1))
|
||||
.squeeze(0);
|
||||
cell_state
|
||||
.to_data()
|
||||
|
|
|
@ -421,14 +421,14 @@ mod tests {
|
|||
let mut cache = transformer.new_autoregressive_cache();
|
||||
|
||||
for i in 1..seq_length + 1 {
|
||||
let target = target.clone().index([0..batch_size, 0..i, 0..d_model]);
|
||||
let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
|
||||
|
||||
let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
|
||||
let input = TransformerDecoderInput::new(target.clone(), memory.clone())
|
||||
.target_mask_attn(mask_attn);
|
||||
let next_tok = transformer // Greedy sampling
|
||||
.forward_autoregressive_inference(input, &mut cache)
|
||||
.index([0..batch_size, i - 1..i, 0..d_model]);
|
||||
.slice([0..batch_size, i - 1..i, 0..d_model]);
|
||||
output_2.push(next_tok);
|
||||
}
|
||||
|
||||
|
|
|
@ -359,11 +359,11 @@ mod tests {
|
|||
let mut cache = transformer.new_autoregressive_cache();
|
||||
|
||||
for i in 1..seq_length + 1 {
|
||||
let tensor = tensor.clone().index([0..batch_size, 0..i, 0..d_model]);
|
||||
let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
|
||||
let input = TransformerEncoderInput::new(tensor.clone());
|
||||
let next_tok = transformer
|
||||
.forward_autoregressive_inference(input, &mut cache)
|
||||
.index([0..batch_size, i - 1..i, 0..d_model]);
|
||||
.slice([0..batch_size, i - 1..i, 0..d_model]);
|
||||
output_2.push(next_tok);
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ use std::marker::PhantomData;
|
|||
/// want a probability distribution that is computed lazily.
|
||||
pub struct ShuffledDataset<D, I> {
|
||||
dataset: D,
|
||||
indexes: Vec<usize>,
|
||||
indices: Vec<usize>,
|
||||
input: PhantomData<I>,
|
||||
}
|
||||
|
||||
|
@ -16,15 +16,15 @@ where
|
|||
{
|
||||
/// Creates a new shuffled dataset.
|
||||
pub fn new(dataset: D, rng: &mut StdRng) -> Self {
|
||||
let mut indexes = Vec::with_capacity(dataset.len());
|
||||
let mut indices = Vec::with_capacity(dataset.len());
|
||||
for i in 0..dataset.len() {
|
||||
indexes.push(i);
|
||||
indices.push(i);
|
||||
}
|
||||
indexes.shuffle(rng);
|
||||
indices.shuffle(rng);
|
||||
|
||||
Self {
|
||||
dataset,
|
||||
indexes,
|
||||
indices,
|
||||
input: PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ where
|
|||
I: Clone + Send + Sync,
|
||||
{
|
||||
fn get(&self, index: usize) -> Option<I> {
|
||||
let index = match self.indexes.get(index) {
|
||||
let index = match self.indices.get(index) {
|
||||
Some(index) => index,
|
||||
None => return None,
|
||||
};
|
||||
|
|
|
@ -28,22 +28,22 @@ impl<E> NdArrayOps<E>
|
|||
where
|
||||
E: Copy,
|
||||
{
|
||||
pub fn index<const D1: usize, const D2: usize>(
|
||||
pub fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
let slices = Self::to_slice_args::<D1, D2>(indexes);
|
||||
let slices = Self::to_slice_args::<D1, D2>(ranges);
|
||||
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
pub fn index_assign<const D1: usize, const D2: usize>(
|
||||
pub fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: NdArrayTensor<E, D1>,
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
let slices = Self::to_slice_args::<D1, D2>(indexes);
|
||||
let slices = Self::to_slice_args::<D1, D2>(ranges);
|
||||
let mut array = tensor.array.into_owned();
|
||||
array.slice_mut(slices.as_slice()).assign(&value.array);
|
||||
let array = array.into_shared();
|
||||
|
@ -77,7 +77,7 @@ where
|
|||
}
|
||||
|
||||
fn to_slice_args<const D1: usize, const D2: usize>(
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> [SliceInfoElem; D1] {
|
||||
let mut slices = [SliceInfoElem::NewAxis; D1];
|
||||
for i in 0..D1 {
|
||||
|
@ -89,8 +89,8 @@ where
|
|||
}
|
||||
} else {
|
||||
slices[i] = SliceInfoElem::Slice {
|
||||
start: indexes[i].start as isize,
|
||||
end: Some(indexes[i].end as isize),
|
||||
start: ranges[i].start as isize,
|
||||
end: Some(ranges[i].end as isize),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
|
@ -211,31 +211,31 @@ where
|
|||
pub fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
mut tensor: NdArrayTensor<E, D>,
|
||||
mut indexes: NdArrayTensor<i64, D>,
|
||||
mut indices: NdArrayTensor<i64, D>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
if dim != D - 1 {
|
||||
tensor.array.swap_axes(D - 1, dim);
|
||||
indexes.array.swap_axes(D - 1, dim);
|
||||
indices.array.swap_axes(D - 1, dim);
|
||||
}
|
||||
let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape());
|
||||
let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]);
|
||||
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);
|
||||
let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape());
|
||||
let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indices.dims[D - 1]);
|
||||
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);
|
||||
|
||||
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
|
||||
let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
|
||||
let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
|
||||
let mut output = Array2::zeros((batch_size, size_index));
|
||||
|
||||
for b in 0..batch_size {
|
||||
let indexes = indexes.slice(s!(b, ..));
|
||||
let indices = indices.slice(s!(b, ..));
|
||||
|
||||
for (i, index) in indexes.iter().enumerate() {
|
||||
for (i, index) in indices.iter().enumerate() {
|
||||
output[[b, i]] = tensor[[b, *index as usize]];
|
||||
}
|
||||
}
|
||||
|
||||
let mut output = NdArrayOps::reshape(
|
||||
NdArrayTensor::<E, 2>::new(output.into_shared().into_dyn()),
|
||||
shape_indexes,
|
||||
shape_indices,
|
||||
);
|
||||
|
||||
if dim != D - 1 {
|
||||
|
@ -248,36 +248,36 @@ where
|
|||
pub fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
mut tensor: NdArrayTensor<E, D>,
|
||||
mut indexes: NdArrayTensor<i64, D>,
|
||||
mut indices: NdArrayTensor<i64, D>,
|
||||
mut value: NdArrayTensor<E, D>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
if dim != D - 1 {
|
||||
tensor.array.swap_axes(D - 1, dim);
|
||||
indexes.array.swap_axes(D - 1, dim);
|
||||
indices.array.swap_axes(D - 1, dim);
|
||||
value.array.swap_axes(D - 1, dim);
|
||||
}
|
||||
|
||||
let (shape_tensor, shape_indexes, shape_value) =
|
||||
(tensor.shape(), indexes.shape(), value.shape());
|
||||
let (shape_tensor, shape_indices, shape_value) =
|
||||
(tensor.shape(), indices.shape(), value.shape());
|
||||
let (size_tensor, size_index, size_value) = (
|
||||
shape_tensor.dims[D - 1],
|
||||
shape_indexes.dims[D - 1],
|
||||
shape_indices.dims[D - 1],
|
||||
shape_value.dims[D - 1],
|
||||
);
|
||||
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);
|
||||
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);
|
||||
|
||||
if shape_value != shape_indexes {
|
||||
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims);
|
||||
if shape_value != shape_indices {
|
||||
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indices.dims, shape_value.dims);
|
||||
}
|
||||
|
||||
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
|
||||
let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
|
||||
let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array;
|
||||
let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
|
||||
|
||||
for b in 0..batch_size {
|
||||
let indexes = indexes.slice(s!(b, ..));
|
||||
let indices = indices.slice(s!(b, ..));
|
||||
|
||||
for (i, index) in indexes.iter().enumerate() {
|
||||
for (i, index) in indices.iter().enumerate() {
|
||||
let index = *index as usize;
|
||||
tensor[[b, index]] += value[[b, i]];
|
||||
}
|
||||
|
@ -331,28 +331,28 @@ where
|
|||
|
||||
fn gather_batch_size<const D: usize>(
|
||||
shape_tensor: &Shape<D>,
|
||||
shape_indexes: &Shape<D>,
|
||||
shape_indices: &Shape<D>,
|
||||
) -> usize {
|
||||
let mut batch_size = 1;
|
||||
|
||||
for i in 0..D - 1 {
|
||||
if shape_tensor.dims[i] != shape_indexes.dims[i] {
|
||||
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indexes.dims);
|
||||
if shape_tensor.dims[i] != shape_indices.dims[i] {
|
||||
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indices.dims);
|
||||
}
|
||||
batch_size *= shape_indexes.dims[i];
|
||||
batch_size *= shape_indices.dims[i];
|
||||
}
|
||||
|
||||
batch_size
|
||||
}
|
||||
|
||||
pub fn index_select<const D: usize>(
|
||||
pub fn select<const D: usize>(
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: NdArrayTensor<i64, 1>,
|
||||
indices: NdArrayTensor<i64, 1>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
let array = tensor.array.select(
|
||||
Axis(dim),
|
||||
&indexes
|
||||
&indices
|
||||
.array
|
||||
.into_iter()
|
||||
.map(|i| i as usize)
|
||||
|
@ -362,15 +362,15 @@ where
|
|||
NdArrayTensor::new(array.into_shared())
|
||||
}
|
||||
|
||||
pub fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
pub fn select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
dim: usize,
|
||||
indexes: NdArrayTensor<i64, 1>,
|
||||
indices: NdArrayTensor<i64, 1>,
|
||||
value: NdArrayTensor<E, D2>,
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
let mut output_array = tensor.array.into_owned();
|
||||
|
||||
for (index_value, index) in indexes.array.into_iter().enumerate() {
|
||||
for (index_value, index) in indices.array.into_iter().enumerate() {
|
||||
let mut view = output_array.index_axis_mut(Axis(dim), index as usize);
|
||||
let value = value.array.index_axis(Axis(dim), index_value);
|
||||
|
||||
|
|
|
@ -57,11 +57,11 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
|
|||
NdArrayOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
fn bool_slice<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<bool, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<bool, D1> {
|
||||
NdArrayOps::index(tensor, indexes)
|
||||
NdArrayOps::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(
|
||||
|
@ -85,12 +85,12 @@ impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend
|
|||
NdArrayTensor::from_data(Data::new(values, shape))
|
||||
}
|
||||
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
fn bool_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
|
||||
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1> {
|
||||
NdArrayOps::index_assign(tensor, indexes, value)
|
||||
NdArrayOps::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(
|
||||
|
|
|
@ -51,11 +51,11 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
|
|||
NdArrayOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_index<const D1: usize, const D2: usize>(
|
||||
fn int_slice<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<i64, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<i64, D1> {
|
||||
NdArrayOps::index(tensor, indexes)
|
||||
NdArrayOps::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn int_device<const D: usize>(
|
||||
|
@ -88,12 +88,12 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
|
|||
NdArrayMathOps::mask_fill(tensor, mask, value)
|
||||
}
|
||||
|
||||
fn int_index_assign<const D1: usize, const D2: usize>(
|
||||
fn int_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<i64, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: NdArrayTensor<i64, D1>,
|
||||
) -> NdArrayTensor<i64, D1> {
|
||||
NdArrayOps::index_assign(tensor, indexes, value)
|
||||
NdArrayOps::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn int_cat<const D: usize>(
|
||||
|
@ -283,35 +283,35 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
|
|||
fn int_gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<i64, D>,
|
||||
indexes: NdArrayTensor<i64, D>,
|
||||
indices: NdArrayTensor<i64, D>,
|
||||
) -> NdArrayTensor<i64, D> {
|
||||
NdArrayMathOps::gather(dim, tensor, indexes)
|
||||
NdArrayMathOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<i64, D>,
|
||||
indexes: NdArrayTensor<i64, D>,
|
||||
indices: NdArrayTensor<i64, D>,
|
||||
value: NdArrayTensor<i64, D>,
|
||||
) -> NdArrayTensor<i64, D> {
|
||||
NdArrayMathOps::scatter(dim, tensor, indexes, value)
|
||||
NdArrayMathOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_index_select_dim<const D: usize>(
|
||||
fn int_select<const D: usize>(
|
||||
tensor: NdArrayTensor<i64, D>,
|
||||
dim: usize,
|
||||
indexes: NdArrayTensor<i64, 1>,
|
||||
indices: NdArrayTensor<i64, 1>,
|
||||
) -> NdArrayTensor<i64, D> {
|
||||
NdArrayMathOps::index_select(tensor, dim, indexes)
|
||||
NdArrayMathOps::select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<i64, D1>,
|
||||
fn int_select_assign<const D: usize>(
|
||||
tensor: NdArrayTensor<i64, D>,
|
||||
dim: usize,
|
||||
indexes: NdArrayTensor<i64, 1>,
|
||||
value: NdArrayTensor<i64, D2>,
|
||||
) -> NdArrayTensor<i64, D1> {
|
||||
NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
|
||||
indices: NdArrayTensor<i64, 1>,
|
||||
value: NdArrayTensor<i64, D>,
|
||||
) -> NdArrayTensor<i64, D> {
|
||||
NdArrayMathOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
fn int_argmax<const D: usize>(
|
||||
tensor: NdArrayTensor<i64, D>,
|
||||
|
|
|
@ -60,7 +60,7 @@ pub(crate) fn max_pool2d<E: FloatNdArrayElement>(
|
|||
NdArrayTensor::new(output.into_dyn().into_shared())
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
|
||||
pub(crate) fn max_pool2d_with_indices<E: FloatNdArrayElement>(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
|
@ -78,10 +78,10 @@ pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
|
|||
let x = apply_padding_4d(x, padding, inf).array;
|
||||
|
||||
let mut output = Array4::from_elem((batch_size, channels, out_height, out_width), inf);
|
||||
let mut indexes = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));
|
||||
let mut indices = Array4::<i64>::zeros((batch_size, channels, out_height, out_width));
|
||||
|
||||
let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
|
||||
let unsafe_shared_indexes = UnsafeSharedRef::new(&mut indexes);
|
||||
let unsafe_shared_indices = UnsafeSharedRef::new(&mut indices);
|
||||
|
||||
run_par!(|| {
|
||||
iter_par!(0, batch_size * channels).for_each(|k| unsafe {
|
||||
|
@ -89,7 +89,7 @@ pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
|
|||
let c = k % channels;
|
||||
|
||||
let output = unsafe_shared_out.get();
|
||||
let indexes = unsafe_shared_indexes.get();
|
||||
let indices = unsafe_shared_indices.get();
|
||||
|
||||
for oh in 0..out_height {
|
||||
for ow in 0..out_width {
|
||||
|
@ -115,16 +115,16 @@ pub(crate) fn max_pool2d_with_indexes<E: FloatNdArrayElement>(
|
|||
}
|
||||
|
||||
output[[b, c, oh, ow]] = max_val;
|
||||
indexes[[b, c, oh, ow]] = index;
|
||||
indices[[b, c, oh, ow]] = index;
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
let output = NdArrayTensor::new(output.into_dyn().into_shared());
|
||||
let indexes = NdArrayTensor::new(indexes.into_dyn().into_shared());
|
||||
let indices = NdArrayTensor::new(indices.into_dyn().into_shared());
|
||||
|
||||
(output, indexes)
|
||||
(output, indices)
|
||||
}
|
||||
|
||||
pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
||||
|
@ -133,13 +133,13 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
|||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
output_grad: NdArrayTensor<E, 4>,
|
||||
indexes: NdArrayTensor<i64, 4>,
|
||||
indices: NdArrayTensor<i64, 4>,
|
||||
) -> NdArrayTensor<E, 4> {
|
||||
let [_batch_size, _channels, height, width] = output_grad.shape().dims;
|
||||
let [batch_size, channels, height_x, width_x] = x.shape().dims;
|
||||
|
||||
let output_grad = output_grad.array;
|
||||
let indexes = indexes.array;
|
||||
let indices = indices.array;
|
||||
|
||||
let mut output = Array4::zeros((batch_size, channels, height_x, width_x));
|
||||
|
||||
|
@ -154,7 +154,7 @@ pub(crate) fn max_pool2d_backward<E: FloatNdArrayElement>(
|
|||
|
||||
for h in 0..height {
|
||||
for w in 0..width {
|
||||
let index = indexes[[b, c, h, w]];
|
||||
let index = indices[[b, c, h, w]];
|
||||
let grad = output_grad[[b, c, h, w]];
|
||||
|
||||
let index_h = index as usize / width_x;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::{
|
||||
avgpool::{avg_pool2d, avg_pool2d_backward},
|
||||
conv::{conv2d, conv_transpose2d},
|
||||
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indexes},
|
||||
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
|
||||
};
|
||||
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
|
||||
use burn_tensor::ops::*;
|
||||
|
@ -53,24 +53,24 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
max_pool2d(x, kernel_size, stride, padding)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes(
|
||||
fn max_pool2d_with_indices(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
) -> MaxPool2dWithIndexes<NdArrayBackend<E>> {
|
||||
let (output, indexes) = max_pool2d_with_indexes(x, kernel_size, stride, padding);
|
||||
) -> MaxPool2dWithIndices<NdArrayBackend<E>> {
|
||||
let (output, indices) = max_pool2d_with_indices(x, kernel_size, stride, padding);
|
||||
|
||||
MaxPool2dWithIndexes::new(output, indexes)
|
||||
MaxPool2dWithIndices::new(output, indices)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes_backward(
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
output_grad: NdArrayTensor<E, 4>,
|
||||
indexes: NdArrayTensor<i64, 4>,
|
||||
indices: NdArrayTensor<i64, 4>,
|
||||
) -> MaxPool2dBackward<NdArrayBackend<E>> {
|
||||
MaxPool2dBackward::new(max_pool2d_backward(
|
||||
x,
|
||||
|
@ -78,7 +78,7 @@ impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
stride,
|
||||
padding,
|
||||
output_grad,
|
||||
indexes,
|
||||
indices,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ pub(crate) fn apply_padding_4d<E: FloatNdArrayElement>(
|
|||
);
|
||||
let mut x_new = NdArrayTensor::new(x_new.into_shared().into_dyn());
|
||||
|
||||
x_new = NdArrayBackend::index_assign(
|
||||
x_new = NdArrayBackend::slice_assign(
|
||||
x_new,
|
||||
[
|
||||
0..batch_size,
|
||||
|
|
|
@ -154,50 +154,50 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
|||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
indexes: NdArrayTensor<i64, D>,
|
||||
indices: NdArrayTensor<i64, D>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
NdArrayMathOps::gather(dim, tensor, indexes)
|
||||
NdArrayMathOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
indexes: NdArrayTensor<i64, D>,
|
||||
indices: NdArrayTensor<i64, D>,
|
||||
value: NdArrayTensor<E, D>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
NdArrayMathOps::scatter(dim, tensor, indexes, value)
|
||||
NdArrayMathOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn index_select<const D: usize>(
|
||||
fn select<const D: usize>(
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: NdArrayTensor<i64, 1>,
|
||||
indices: NdArrayTensor<i64, 1>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
NdArrayMathOps::index_select(tensor, dim, indexes)
|
||||
NdArrayMathOps::select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: NdArrayTensor<i64, 1>,
|
||||
value: NdArrayTensor<E, D2>,
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
|
||||
indices: NdArrayTensor<i64, 1>,
|
||||
value: NdArrayTensor<E, D>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
NdArrayMathOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
NdArrayOps::index(tensor, indexes)
|
||||
NdArrayOps::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: NdArrayTensor<E, D1>,
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
NdArrayOps::index_assign(tensor, indexes, value)
|
||||
NdArrayOps::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn mask_where<const D: usize>(
|
||||
|
|
|
@ -18,14 +18,14 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
|
||||
}
|
||||
|
||||
pub fn index<const D1: usize, const D2: usize>(
|
||||
pub fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> TchTensor<E, D1> {
|
||||
let storage = tensor.storage.clone();
|
||||
let mut tensor = tensor.tensor.shallow_clone();
|
||||
|
||||
for (i, index) in indexes.iter().enumerate().take(D2) {
|
||||
for (i, index) in ranges.iter().enumerate().take(D2) {
|
||||
let start = index.start as i64;
|
||||
let length = (index.end - index.start) as i64;
|
||||
tensor = tensor.narrow(i as i64, start, length);
|
||||
|
@ -34,9 +34,9 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn index_assign<const D1: usize, const D2: usize>(
|
||||
pub fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: TchTensor<E, D1>,
|
||||
) -> TchTensor<E, D1> {
|
||||
let tensor_original = tensor.tensor.copy();
|
||||
|
@ -44,7 +44,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
|
||||
let mut tensor = tensor_original.view_(tch_shape.dims);
|
||||
|
||||
for (i, index) in indexes.into_iter().enumerate().take(D2) {
|
||||
for (i, index) in ranges.into_iter().enumerate().take(D2) {
|
||||
let start = index.start as i64;
|
||||
let length = (index.end - index.start) as i64;
|
||||
|
||||
|
@ -59,10 +59,10 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
pub fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: TchTensor<E, D>,
|
||||
indexes: TchTensor<i64, D>,
|
||||
indices: TchTensor<i64, D>,
|
||||
) -> TchTensor<E, D> {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.gather(dim as i64, &indexes.tensor, false);
|
||||
let tensor = tensor.tensor.gather(dim as i64, &indices.tensor, false);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
@ -70,13 +70,13 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
pub fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: TchTensor<E, D>,
|
||||
indexes: TchTensor<i64, D>,
|
||||
indices: TchTensor<i64, D>,
|
||||
value: TchTensor<E, D>,
|
||||
) -> TchTensor<E, D> {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor
|
||||
.tensor
|
||||
.scatter_add(dim as i64, &indexes.tensor, &value.tensor);
|
||||
.scatter_add(dim as i64, &indices.tensor, &value.tensor);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
@ -84,25 +84,25 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
pub fn index_select_dim<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: TchTensor<i64, 1>,
|
||||
indices: TchTensor<i64, 1>,
|
||||
) -> TchTensor<E, D> {
|
||||
let storage = tensor.storage.clone();
|
||||
let tensor = tensor.tensor.index_select(dim as i64, &indexes.tensor);
|
||||
let tensor = tensor.tensor.index_select(dim as i64, &indices.tensor);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
pub fn select_assign<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: TchTensor<i64, 1>,
|
||||
value: TchTensor<E, D2>,
|
||||
) -> TchTensor<E, D1> {
|
||||
let mut indices = Vec::with_capacity(D1);
|
||||
for _ in 0..D1 {
|
||||
indices_tensor: TchTensor<i64, 1>,
|
||||
value: TchTensor<E, D>,
|
||||
) -> TchTensor<E, D> {
|
||||
let mut indices = Vec::with_capacity(D);
|
||||
for _ in 0..D {
|
||||
indices.push(None);
|
||||
}
|
||||
indices[dim] = Some(indexes.tensor);
|
||||
indices[dim] = Some(indices_tensor.tensor);
|
||||
|
||||
tensor.unary_ops(
|
||||
|mut tensor| tensor.index_put_(&indices, &value.tensor, true),
|
||||
|
@ -321,41 +321,41 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
|
||||
pub fn max_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, _indexes) = tensor.tensor.max_dim(dim as i64, true);
|
||||
let (tensor, _indices) = tensor.tensor.max_dim(dim as i64, true);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn max_dim_with_indexes<const D: usize>(
|
||||
pub fn max_dim_with_indices<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, indexes) = tensor.tensor.max_dim(dim as i64, true);
|
||||
let (tensor, indices) = tensor.tensor.max_dim(dim as i64, true);
|
||||
|
||||
let tensor = TchTensor::from_existing(tensor, storage);
|
||||
let indexes = TchTensor::new(indexes);
|
||||
let indices = TchTensor::new(indices);
|
||||
|
||||
(tensor, indexes)
|
||||
(tensor, indices)
|
||||
}
|
||||
|
||||
pub fn min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, _indexes) = tensor.tensor.min_dim(dim as i64, true);
|
||||
let (tensor, _indices) = tensor.tensor.min_dim(dim as i64, true);
|
||||
|
||||
TchTensor::from_existing(tensor, storage)
|
||||
}
|
||||
|
||||
pub fn min_dim_with_indexes<const D: usize>(
|
||||
pub fn min_dim_with_indices<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
|
||||
let storage = tensor.storage.clone();
|
||||
let (tensor, indexes) = tensor.tensor.min_dim(dim as i64, true);
|
||||
let (tensor, indices) = tensor.tensor.min_dim(dim as i64, true);
|
||||
|
||||
let tensor = TchTensor::from_existing(tensor, storage);
|
||||
let indexes = TchTensor::new(indexes);
|
||||
let indices = TchTensor::new(indices);
|
||||
|
||||
(tensor, indexes)
|
||||
(tensor, indices)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,18 +60,18 @@ impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
fn bool_slice<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<bool, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> TchTensor<bool, D1> {
|
||||
TchOps::index(tensor, indexes)
|
||||
TchOps::slice(tensor, ranges)
|
||||
}
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
fn bool_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<bool, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
value: TchTensor<bool, D1>,
|
||||
) -> TchTensor<bool, D1> {
|
||||
TchOps::index_assign(tensor, indexes, value)
|
||||
TchOps::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(
|
||||
|
|
|
@ -57,18 +57,18 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn int_index<const D1: usize, const D2: usize>(
|
||||
fn int_slice<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<i64, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> TchTensor<i64, D1> {
|
||||
TchOps::index(tensor, indexes)
|
||||
TchOps::slice(tensor, ranges)
|
||||
}
|
||||
fn int_index_assign<const D1: usize, const D2: usize>(
|
||||
fn int_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<i64, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
ranges: [std::ops::Range<usize>; D2],
|
||||
value: TchTensor<i64, D1>,
|
||||
) -> TchTensor<i64, D1> {
|
||||
TchOps::index_assign(tensor, indexes, value)
|
||||
TchOps::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn int_cat<const D: usize>(tensors: Vec<TchTensor<i64, D>>, dim: usize) -> TchTensor<i64, D> {
|
||||
|
@ -234,35 +234,35 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
fn int_gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: TchTensor<i64, D>,
|
||||
indexes: TchTensor<i64, D>,
|
||||
indices: TchTensor<i64, D>,
|
||||
) -> TchTensor<i64, D> {
|
||||
TchOps::gather(dim, tensor, indexes)
|
||||
TchOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: TchTensor<i64, D>,
|
||||
indexes: TchTensor<i64, D>,
|
||||
indices: TchTensor<i64, D>,
|
||||
value: TchTensor<i64, D>,
|
||||
) -> TchTensor<i64, D> {
|
||||
TchOps::scatter(dim, tensor, indexes, value)
|
||||
TchOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_index_select_dim<const D: usize>(
|
||||
fn int_select<const D: usize>(
|
||||
tensor: TchTensor<i64, D>,
|
||||
dim: usize,
|
||||
indexes: TchTensor<i64, 1>,
|
||||
indices: TchTensor<i64, 1>,
|
||||
) -> TchTensor<i64, D> {
|
||||
TchOps::index_select_dim(tensor, dim, indexes)
|
||||
TchOps::index_select_dim(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<i64, D1>,
|
||||
fn int_select_assign<const D: usize>(
|
||||
tensor: TchTensor<i64, D>,
|
||||
dim: usize,
|
||||
indexes: TchTensor<i64, 1>,
|
||||
value: TchTensor<i64, D2>,
|
||||
) -> TchTensor<i64, D1> {
|
||||
TchOps::index_select_dim_assign(tensor, dim, indexes, value)
|
||||
indices: TchTensor<i64, 1>,
|
||||
value: TchTensor<i64, D>,
|
||||
) -> TchTensor<i64, D> {
|
||||
TchOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn int_mask_where<const D: usize>(
|
||||
|
@ -301,21 +301,21 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchOps::max_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_max_dim_with_indexes<const D: usize>(
|
||||
fn int_max_dim_with_indices<const D: usize>(
|
||||
tensor: TchTensor<i64, D>,
|
||||
dim: usize,
|
||||
) -> (TchTensor<i64, D>, TchTensor<i64, D>) {
|
||||
TchOps::max_dim_with_indexes(tensor, dim)
|
||||
TchOps::max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_min_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
|
||||
TchOps::min_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_min_dim_with_indexes<const D: usize>(
|
||||
fn int_min_dim_with_indices<const D: usize>(
|
||||
tensor: TchTensor<i64, D>,
|
||||
dim: usize,
|
||||
) -> (TchTensor<i64, D>, TchTensor<i64, D>) {
|
||||
TchOps::min_dim_with_indexes(tensor, dim)
|
||||
TchOps::min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
use crate::{element::TchElement, TchBackend, TchTensor};
|
||||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndexes, ModuleOps,
|
||||
ConvOptions, ConvTransposeOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
};
|
||||
|
||||
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
||||
fn embedding(weights: TchTensor<E, 2>, indexes: TchTensor<i64, 2>) -> TchTensor<E, 3> {
|
||||
let tensor = tch::Tensor::embedding(&weights.tensor, &indexes.tensor, -1, false, false);
|
||||
fn embedding(weights: TchTensor<E, 2>, indices: TchTensor<i64, 2>) -> TchTensor<E, 3> {
|
||||
let tensor = tch::Tensor::embedding(&weights.tensor, &indices.tensor, -1, false, false);
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
@ -13,12 +13,12 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
fn embedding_backward(
|
||||
weights: TchTensor<E, 2>,
|
||||
output: TchTensor<E, 3>,
|
||||
indexes: TchTensor<i64, 2>,
|
||||
indices: TchTensor<i64, 2>,
|
||||
) -> TchTensor<E, 2> {
|
||||
let [n_embedding, _d_model] = weights.shape().dims;
|
||||
let tensor = tch::Tensor::embedding_backward(
|
||||
&output.tensor,
|
||||
&indexes.tensor,
|
||||
&indices.tensor,
|
||||
n_embedding as i64,
|
||||
-1,
|
||||
false,
|
||||
|
@ -181,13 +181,13 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes(
|
||||
fn max_pool2d_with_indices(
|
||||
x: TchTensor<E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
) -> MaxPool2dWithIndexes<TchBackend<E>> {
|
||||
let (tensor, indexes) = tch::Tensor::max_pool2d_with_indices(
|
||||
) -> MaxPool2dWithIndices<TchBackend<E>> {
|
||||
let (tensor, indices) = tch::Tensor::max_pool2d_with_indices(
|
||||
&x.tensor,
|
||||
[kernel_size[0] as i64, kernel_size[1] as i64],
|
||||
[stride[0] as i64, stride[1] as i64],
|
||||
|
@ -196,16 +196,16 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
false,
|
||||
);
|
||||
|
||||
MaxPool2dWithIndexes::new(TchTensor::new(tensor), TchTensor::new(indexes))
|
||||
MaxPool2dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices))
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes_backward(
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: TchTensor<E, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
output_grad: TchTensor<E, 4>,
|
||||
indexes: TchTensor<i64, 4>,
|
||||
indices: TchTensor<i64, 4>,
|
||||
) -> MaxPool2dBackward<TchBackend<E>> {
|
||||
let grad = tch::Tensor::max_pool2d_with_indices_backward(
|
||||
&x.tensor,
|
||||
|
@ -215,7 +215,7 @@ impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
|||
[padding[0] as i64, padding[1] as i64],
|
||||
[1, 1],
|
||||
false,
|
||||
&indexes.tensor,
|
||||
&indices.tensor,
|
||||
);
|
||||
|
||||
MaxPool2dBackward::new(TchTensor::new(grad))
|
||||
|
|
|
@ -182,50 +182,50 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: TchTensor<E, D>,
|
||||
indexes: TchTensor<i64, D>,
|
||||
indices: TchTensor<i64, D>,
|
||||
) -> TchTensor<E, D> {
|
||||
TchOps::gather(dim, tensor, indexes)
|
||||
TchOps::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: TchTensor<E, D>,
|
||||
indexes: TchTensor<i64, D>,
|
||||
indices: TchTensor<i64, D>,
|
||||
value: TchTensor<E, D>,
|
||||
) -> TchTensor<E, D> {
|
||||
TchOps::scatter(dim, tensor, indexes, value)
|
||||
TchOps::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn index_select<const D: usize>(
|
||||
fn select<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: TchTensor<i64, 1>,
|
||||
indices: TchTensor<i64, 1>,
|
||||
) -> TchTensor<E, D> {
|
||||
TchOps::index_select_dim(tensor, dim, indexes)
|
||||
TchOps::index_select_dim(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: TchTensor<i64, 1>,
|
||||
value: TchTensor<E, D2>,
|
||||
) -> TchTensor<E, D1> {
|
||||
TchOps::index_select_dim_assign(tensor, dim, indexes, value)
|
||||
indices: TchTensor<i64, 1>,
|
||||
value: TchTensor<E, D>,
|
||||
) -> TchTensor<E, D> {
|
||||
TchOps::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> TchTensor<E, D1> {
|
||||
TchOps::index(tensor, indexes)
|
||||
TchOps::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: TchTensor<E, D1>,
|
||||
) -> <TchBackend<E> as Backend>::TensorPrimitive<D1> {
|
||||
TchOps::index_assign(tensor, indexes, value)
|
||||
TchOps::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn mask_where<const D: usize>(
|
||||
|
@ -339,22 +339,22 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchOps::max_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
fn max_dim_with_indices<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
|
||||
TchOps::max_dim_with_indexes(tensor, dim)
|
||||
TchOps::max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn min_dim<const D: usize>(tensor: TchTensor<E, D>, dim: usize) -> TchTensor<E, D> {
|
||||
TchOps::min_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
fn min_dim_with_indices<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
) -> (TchTensor<E, D>, TchTensor<i64, D>) {
|
||||
TchOps::min_dim_with_indexes(tensor, dim)
|
||||
TchOps::min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn exp<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
||||
|
|
|
@ -268,11 +268,11 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn should_not_update_inplace_after_index() {
|
||||
fn should_not_update_inplace_after_slice() {
|
||||
let tensor_1 = Tensor::<TchBackend<f32>, 1>::from_floats([4.0, 4.0]);
|
||||
let tensor_2 = tensor_1.clone();
|
||||
|
||||
let tensor_3 = tensor_2.index([0..2]).add_scalar(2.0);
|
||||
let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0);
|
||||
|
||||
assert_ne!(tensor_3.to_data().value, tensor_1.to_data().value);
|
||||
}
|
||||
|
|
|
@ -211,18 +211,18 @@ where
|
|||
///
|
||||
/// fn example<B: Backend>() {
|
||||
/// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 3, 3]));
|
||||
/// let tensor_indexed = tensor.index([0..1, 0..3, 1..2]);
|
||||
/// println!("{:?}", tensor_indexed.shape());
|
||||
/// // Shape { dims: [1, 3, 2] }
|
||||
/// let tensor_slices = tensor.slice([0..1, 0..3, 1..2]);
|
||||
/// println!("{:?}", tensor_slices.dims()); // [1, 3, 2]
|
||||
///
|
||||
/// }
|
||||
/// ```
|
||||
pub fn index<const D2: usize>(self, indexes: [core::ops::Range<usize>; D2]) -> Self {
|
||||
check!(TensorCheck::index(&self.shape(), &indexes));
|
||||
Self::new(K::index(self.primitive, indexes))
|
||||
pub fn slice<const D2: usize>(self, ranges: [core::ops::Range<usize>; D2]) -> Self {
|
||||
check!(TensorCheck::slice(&self.shape(), &ranges));
|
||||
Self::new(K::slice(self.primitive, ranges))
|
||||
}
|
||||
|
||||
/// Returns a copy of the current tensor with the selected elements changed to the new ones at
|
||||
/// the selected indexes.
|
||||
/// the selected indices.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
|
@ -238,22 +238,21 @@ where
|
|||
/// fn example<B: Backend>() {
|
||||
/// let tensor = Tensor::<B, 3>::ones([2, 3, 3]);
|
||||
/// let values = Tensor::<B, 3>::zeros([1, 1, 1]);
|
||||
/// let tensor_indexed = tensor.index_assign([0..1, 0..1, 0..1], values);
|
||||
/// println!("{:?}", tensor_indexed.shape());
|
||||
/// // Shape { dims: [2, 3, 3] }
|
||||
/// let tensor_sliced = tensor.slice_assign([0..1, 0..1, 0..1], values);
|
||||
/// println!("{:?}", tensor_sliced.dims()); // [2, 3, 3]
|
||||
/// }
|
||||
/// ```
|
||||
pub fn index_assign<const D2: usize>(
|
||||
pub fn slice_assign<const D2: usize>(
|
||||
self,
|
||||
indexes: [core::ops::Range<usize>; D2],
|
||||
ranges: [core::ops::Range<usize>; D2],
|
||||
values: Self,
|
||||
) -> Self {
|
||||
check!(TensorCheck::index_assign(
|
||||
check!(TensorCheck::slice_assign(
|
||||
&self.shape(),
|
||||
&values.shape(),
|
||||
&indexes
|
||||
&ranges
|
||||
));
|
||||
Self::new(K::index_assign(self.primitive, indexes, values.primitive))
|
||||
Self::new(K::slice_assign(self.primitive, ranges, values.primitive))
|
||||
}
|
||||
|
||||
/// Returns the device of the current tensor.
|
||||
|
@ -358,7 +357,7 @@ where
|
|||
multi_index[depth] = i;
|
||||
let range: [core::ops::Range<usize>; D] =
|
||||
core::array::from_fn(|i| multi_index[i]..multi_index[i] + 1);
|
||||
let elem = &self.clone().index(range).to_data().value[0];
|
||||
let elem = &self.clone().slice(range).to_data().value[0];
|
||||
acc.push_str(&format!("{elem:?}"));
|
||||
}
|
||||
} else {
|
||||
|
@ -480,12 +479,12 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
shape: Shape<D2>,
|
||||
) -> Self::Primitive<D2>;
|
||||
|
||||
/// Select tensor elements corresponding for the given indexes.
|
||||
/// Select tensor elements corresponding for the given ranges.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes of the elements to select.
|
||||
/// * `ranges` - The ranges of the elements to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -497,19 +496,19 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For selecting elements of a tensor, users should prefer the [Tensor::index](Tensor::index) function,
|
||||
/// For selecting elements of a tensor, users should prefer the [Tensor::slice](Tensor::slice) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
range: [Range<usize>; D2],
|
||||
) -> Self::Primitive<D1>;
|
||||
|
||||
/// Assigns the given value to the tensor elements corresponding for the given indexes.
|
||||
/// Assigns the given value to the tensor elements corresponding for the given ranges.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes of the elements to select.
|
||||
/// * `ranges` - The ranges of the elements to select.
|
||||
/// * `value` - The value to assign.
|
||||
///
|
||||
/// # Returns
|
||||
|
@ -522,11 +521,11 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For assigning values to elements of a tensor, users should prefer the [Tensor::index_assign](Tensor::index_assign) function,
|
||||
/// For assigning values to elements of a tensor, users should prefer the [Tensor::slice_assign](Tensor::slice_assign) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: Self::Primitive<D1>,
|
||||
) -> Self::Primitive<D1>;
|
||||
|
||||
|
@ -712,19 +711,19 @@ impl<B: Backend> BasicOps<B> for Float {
|
|||
B::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> Self::Primitive<D1> {
|
||||
B::index(tensor, indexes)
|
||||
B::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: Self::Primitive<D1>,
|
||||
) -> Self::Primitive<D1> {
|
||||
B::index_assign(tensor, indexes, value)
|
||||
B::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
|
||||
|
@ -786,19 +785,19 @@ impl<B: Backend> BasicOps<B> for Int {
|
|||
B::int_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> Self::Primitive<D1> {
|
||||
B::int_index(tensor, indexes)
|
||||
B::int_slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: Self::Primitive<D1>,
|
||||
) -> Self::Primitive<D1> {
|
||||
B::int_index_assign(tensor, indexes, value)
|
||||
B::int_slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
|
||||
|
@ -860,19 +859,19 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
B::bool_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> Self::Primitive<D1> {
|
||||
B::bool_index(tensor, indexes)
|
||||
B::bool_slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: Self::Primitive<D1>,
|
||||
) -> Self::Primitive<D1> {
|
||||
B::bool_index_assign(tensor, indexes, value)
|
||||
B::bool_slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
|
||||
|
|
|
@ -264,56 +264,56 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn index<const D1: usize, const D2: usize>(
|
||||
pub(crate) fn slice<const D1: usize, const D2: usize>(
|
||||
shape: &Shape<D1>,
|
||||
indexes: &[Range<usize>; D2],
|
||||
ranges: &[Range<usize>; D2],
|
||||
) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
let n_dims_tensor = D1;
|
||||
let n_dims_indexes = D2;
|
||||
let n_dims_ranges = D2;
|
||||
|
||||
if n_dims_tensor < n_dims_indexes {
|
||||
check = check.register("Index",
|
||||
TensorError::new ("The provided indexes array has a higher number of dimensions than the current tensor.")
|
||||
if n_dims_tensor < n_dims_ranges {
|
||||
check = check.register("Slice",
|
||||
TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.")
|
||||
.details(
|
||||
format!(
|
||||
"The indexes array must be smaller or equal to the tensor number of dimensions. \
|
||||
Tensor number of dimensions: {n_dims_tensor}, indexes array lenght {n_dims_indexes}."
|
||||
"The ranges array must be smaller or equal to the tensor number of dimensions. \
|
||||
Tensor number of dimensions: {n_dims_tensor}, ranges array lenght {n_dims_ranges}."
|
||||
)));
|
||||
}
|
||||
|
||||
for i in 0..usize::min(D1, D2) {
|
||||
let d_tensor = shape.dims[i];
|
||||
let index = indexes.get(i).unwrap();
|
||||
let range = ranges.get(i).unwrap();
|
||||
|
||||
if index.end > d_tensor {
|
||||
if range.end > d_tensor {
|
||||
check = check.register(
|
||||
"Index",
|
||||
TensorError::new("The provided indexes array has a range that exceeds the current tensor size.")
|
||||
"Slice",
|
||||
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
|
||||
.details(format!(
|
||||
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
|
||||
Tensor shape {:?}, provided indexes {:?}.",
|
||||
index.start,
|
||||
index.end,
|
||||
Tensor shape {:?}, provided ranges {:?}.",
|
||||
range.start,
|
||||
range.end,
|
||||
d_tensor,
|
||||
i,
|
||||
shape.dims,
|
||||
indexes,
|
||||
ranges,
|
||||
)));
|
||||
}
|
||||
|
||||
if index.start >= index.end {
|
||||
if range.start >= range.end {
|
||||
check = check.register(
|
||||
"Index",
|
||||
TensorError::new("The provided indexes array has a range where the start index is bigger or equal to its end.")
|
||||
"Slice",
|
||||
TensorError::new("The provided range array has a range where the start index is bigger or equal to its end.")
|
||||
.details(format!(
|
||||
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
|
||||
Tensor shape {:?}, provided indexes {:?}.",
|
||||
Tensor shape {:?}, provided ranges {:?}.",
|
||||
i,
|
||||
index.start,
|
||||
index.end,
|
||||
range.start,
|
||||
range.end,
|
||||
shape.dims,
|
||||
indexes,
|
||||
ranges,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
@ -321,75 +321,75 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn index_assign<const D1: usize, const D2: usize>(
|
||||
pub(crate) fn slice_assign<const D1: usize, const D2: usize>(
|
||||
shape: &Shape<D1>,
|
||||
shape_value: &Shape<D1>,
|
||||
indexes: &[Range<usize>; D2],
|
||||
ranges: &[Range<usize>; D2],
|
||||
) -> Self {
|
||||
let mut check = Self::Ok;
|
||||
|
||||
if D1 < D2 {
|
||||
check = check.register("Index Assign",
|
||||
TensorError::new ("The provided indexes array has a higher number of dimensions than the current tensor.")
|
||||
check = check.register("Slice Assign",
|
||||
TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.")
|
||||
.details(
|
||||
format!(
|
||||
"The indexes array must be smaller or equal to the tensor number of dimensions. \
|
||||
Tensor number of dimensions: {D1}, indexes array lenght {D2}."
|
||||
"The ranges array must be smaller or equal to the tensor number of dimensions. \
|
||||
Tensor number of dimensions: {D1}, ranges array lenght {D2}."
|
||||
)));
|
||||
}
|
||||
|
||||
for i in 0..usize::min(D1, D2) {
|
||||
let d_tensor = shape.dims[i];
|
||||
let d_tensor_value = shape_value.dims[i];
|
||||
let index = indexes.get(i).unwrap();
|
||||
let range = ranges.get(i).unwrap();
|
||||
|
||||
if index.end > d_tensor {
|
||||
if range.end > d_tensor {
|
||||
check = check.register(
|
||||
"Index Assign",
|
||||
TensorError::new("The provided indexes array has a range that exceeds the current tensor size.")
|
||||
"Range Assign",
|
||||
TensorError::new("The provided ranges array has a range that exceeds the current tensor size.")
|
||||
.details(format!(
|
||||
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
|
||||
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
|
||||
index.start,
|
||||
index.end,
|
||||
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
|
||||
range.start,
|
||||
range.end,
|
||||
d_tensor,
|
||||
i,
|
||||
shape.dims,
|
||||
shape_value.dims,
|
||||
indexes,
|
||||
ranges,
|
||||
)));
|
||||
}
|
||||
|
||||
if index.end - index.start != d_tensor_value {
|
||||
if range.end - range.start != d_tensor_value {
|
||||
check = check.register(
|
||||
"Index Assign",
|
||||
TensorError::new("The value tensor must match the amount of elements selected with the indexes array")
|
||||
"Slice Assign",
|
||||
TensorError::new("The value tensor must match the amount of elements selected with the ranges array")
|
||||
.details(format!(
|
||||
"The range ({}..{}) doesn't match the number of elements of the value tensor ({}) at dimension {}. \
|
||||
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
|
||||
index.start,
|
||||
index.end,
|
||||
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
|
||||
range.start,
|
||||
range.end,
|
||||
d_tensor_value,
|
||||
i,
|
||||
shape.dims,
|
||||
shape_value.dims,
|
||||
indexes,
|
||||
ranges,
|
||||
)));
|
||||
}
|
||||
|
||||
if index.start >= index.end {
|
||||
if range.start >= range.end {
|
||||
check = check.register(
|
||||
"Index Assign",
|
||||
TensorError::new("The provided indexes array has a range where the start index is bigger or equal to its end.")
|
||||
"Slice Assign",
|
||||
TensorError::new("The provided ranges array has a range where the start index is bigger or equal to its end.")
|
||||
.details(format!(
|
||||
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
|
||||
Current tensor shape {:?}, value tensor shape {:?}, provided indexes {:?}.",
|
||||
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
|
||||
i,
|
||||
index.start,
|
||||
index.end,
|
||||
range.start,
|
||||
range.end,
|
||||
shape.dims,
|
||||
shape_value.dims,
|
||||
indexes,
|
||||
ranges,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
@ -400,31 +400,31 @@ impl TensorCheck {
|
|||
pub(crate) fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
shape: &Shape<D>,
|
||||
shape_indexes: &Shape<D>,
|
||||
shape_indices: &Shape<D>,
|
||||
) -> Self {
|
||||
Self::check_gather_scatter_indexes(Self::Ok, "Gather", dim, shape, shape_indexes)
|
||||
Self::check_gather_scatter_indices(Self::Ok, "Gather", dim, shape, shape_indices)
|
||||
}
|
||||
|
||||
pub(crate) fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
shape: &Shape<D>,
|
||||
shape_indexes: &Shape<D>,
|
||||
shape_indices: &Shape<D>,
|
||||
shape_value: &Shape<D>,
|
||||
) -> Self {
|
||||
let ops = "Scatter";
|
||||
let mut check =
|
||||
Self::check_gather_scatter_indexes(Self::Ok, ops, dim, shape, shape_indexes);
|
||||
Self::check_gather_scatter_indices(Self::Ok, ops, dim, shape, shape_indices);
|
||||
|
||||
if shape_indexes != shape_value {
|
||||
if shape_indices != shape_value {
|
||||
check = check.register(
|
||||
ops,
|
||||
TensorError::new(
|
||||
"Indexes tensor shape should be the same as the value tensor shape."
|
||||
"Indices tensor shape should be the same as the value tensor shape."
|
||||
.to_string(),
|
||||
)
|
||||
.details(format!(
|
||||
"The shape differs: {:?} != {:?}",
|
||||
shape_indexes.dims, shape_value.dims
|
||||
shape_indices.dims, shape_value.dims
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
@ -432,15 +432,15 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn index_select<const D: usize>(dim: usize) -> Self {
|
||||
Self::check_index_select_basic::<D>(Self::Ok, "index_select", dim)
|
||||
pub(crate) fn select<const D: usize>(dim: usize) -> Self {
|
||||
Self::check_select_basic::<D>(Self::Ok, "select", dim)
|
||||
}
|
||||
|
||||
pub(crate) fn index_select_assign<const D: usize>(dim: usize) -> Self {
|
||||
Self::check_index_select_basic::<D>(Self::Ok, "index_select_assign", dim)
|
||||
pub(crate) fn select_assign<const D: usize>(dim: usize) -> Self {
|
||||
Self::check_select_basic::<D>(Self::Ok, "select_assign", dim)
|
||||
}
|
||||
|
||||
fn check_index_select_basic<const D: usize>(mut check: Self, ops: &str, dim: usize) -> Self {
|
||||
fn check_select_basic<const D: usize>(mut check: Self, ops: &str, dim: usize) -> Self {
|
||||
if dim > D {
|
||||
check = check.register(
|
||||
ops,
|
||||
|
@ -452,12 +452,12 @@ impl TensorCheck {
|
|||
|
||||
check
|
||||
}
|
||||
fn check_gather_scatter_indexes<const D: usize>(
|
||||
fn check_gather_scatter_indices<const D: usize>(
|
||||
mut check: Self,
|
||||
ops: &str,
|
||||
dim: usize,
|
||||
shape: &Shape<D>,
|
||||
shape_indexes: &Shape<D>,
|
||||
shape_indices: &Shape<D>,
|
||||
) -> Self {
|
||||
if dim > D {
|
||||
check = check.register(
|
||||
|
@ -474,9 +474,9 @@ impl TensorCheck {
|
|||
}
|
||||
|
||||
let tensor_dim_i = shape.dims[i];
|
||||
let indexes_dim_i = shape_indexes.dims[i];
|
||||
let indices_dim_i = shape_indices.dims[i];
|
||||
|
||||
if tensor_dim_i != indexes_dim_i {
|
||||
if tensor_dim_i != indices_dim_i {
|
||||
check = check.register(
|
||||
ops,
|
||||
TensorError::new(
|
||||
|
@ -484,7 +484,7 @@ impl TensorCheck {
|
|||
.to_string(),
|
||||
)
|
||||
.details(format!(
|
||||
"The shape differs at dimension {i}: {tensor_dim_i} != {indexes_dim_i}"
|
||||
"The shape differs at dimension {i}: {tensor_dim_i} != {indices_dim_i}"
|
||||
)),
|
||||
);
|
||||
}
|
||||
|
@ -669,7 +669,7 @@ mod tests {
|
|||
#[test]
|
||||
#[should_panic]
|
||||
fn index_range_exceed_dimension() {
|
||||
check!(TensorCheck::index(
|
||||
check!(TensorCheck::slice(
|
||||
&Shape::new([3, 5, 7]),
|
||||
&[0..2, 0..4, 1..8]
|
||||
));
|
||||
|
@ -678,7 +678,7 @@ mod tests {
|
|||
#[test]
|
||||
#[should_panic]
|
||||
fn index_range_exceed_number_of_dimensions() {
|
||||
check!(TensorCheck::index(&Shape::new([3, 5]), &[0..1, 0..1, 0..1]));
|
||||
check!(TensorCheck::slice(&Shape::new([3, 5]), &[0..1, 0..1, 0..1]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
@ -142,7 +142,7 @@ where
|
|||
let mut ranges: [core::ops::Range<usize>; D] = ranges.try_into().unwrap();
|
||||
ranges[D - 1] = index..index + 1;
|
||||
|
||||
tensor.index_assign(ranges, Tensor::ones(Shape::new([1; D])))
|
||||
tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D])))
|
||||
}
|
||||
|
||||
/// Applies the transpose operation.
|
||||
|
|
|
@ -214,36 +214,36 @@ where
|
|||
Self::new(K::mask_fill(self.primitive, mask, value.elem()))
|
||||
}
|
||||
|
||||
/// Gather tensor elements corresponding to the given indexes from the specified dim.
|
||||
/// Gather tensor elements corresponding to the given indices from the specified dim.
|
||||
///
|
||||
/// Example using a 3D tensor:
|
||||
///
|
||||
/// `output[i, j, k] = input[indexes[i, j, k], j, k]; // dim = 0`
|
||||
/// `output[i, j, k] = input[i, indexes[i, j, k], k]; // dim = 1`
|
||||
/// `output[i, j, k] = input[i, j, indexes[i, j, k]]; // dim = 2`
|
||||
/// `output[i, j, k] = input[indices[i, j, k], j, k]; // dim = 0`
|
||||
/// `output[i, j, k] = input[i, indices[i, j, k], k]; // dim = 1`
|
||||
/// `output[i, j, k] = input[i, j, indices[i, j, k]]; // dim = 2`
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The index tensor shoud have the same shape as the original tensor except for the dim
|
||||
/// specified.
|
||||
pub fn gather(self, dim: usize, indexes: Tensor<B, D, Int>) -> Self {
|
||||
pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
|
||||
check!(TensorCheck::gather::<D>(
|
||||
dim,
|
||||
&self.shape(),
|
||||
&indexes.shape()
|
||||
&indices.shape()
|
||||
));
|
||||
|
||||
Self::new(K::gather(dim, self.primitive, indexes))
|
||||
Self::new(K::gather(dim, self.primitive, indices))
|
||||
}
|
||||
|
||||
/// Assign the gathered elements corresponding to the given indexes along the speficied dimension
|
||||
/// Assign the gathered elements corresponding to the given indices along the speficied dimension
|
||||
/// from the value tensor to the original tensor using sum reduction.
|
||||
///
|
||||
/// Example using a 3D tensor:
|
||||
///
|
||||
/// `input[indexes[i, j, k], j, k] += values[i, j, k]; // dim = 0`
|
||||
/// `input[i, indexes[i, j, k], k] += values[i, j, k]; // dim = 1`
|
||||
/// `input[i, j, indexes[i, j, k]] += values[i, j, k]; // dim = 2`
|
||||
/// `input[indices[i, j, k], j, k] += values[i, j, k]; // dim = 0`
|
||||
/// `input[i, indices[i, j, k], k] += values[i, j, k]; // dim = 1`
|
||||
/// `input[i, j, indices[i, j, k]] += values[i, j, k]; // dim = 2`
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
|
@ -251,49 +251,49 @@ where
|
|||
/// dimension. The value and index tensors should have the same shape.
|
||||
///
|
||||
/// Other references to the input tensor will not be modified by this operation.
|
||||
pub fn scatter(self, dim: usize, indexes: Tensor<B, D, Int>, values: Self) -> Self {
|
||||
pub fn scatter(self, dim: usize, indices: Tensor<B, D, Int>, values: Self) -> Self {
|
||||
check!(TensorCheck::scatter::<D>(
|
||||
dim,
|
||||
&self.shape(),
|
||||
&indexes.shape(),
|
||||
&indices.shape(),
|
||||
&values.shape()
|
||||
));
|
||||
|
||||
Self::new(K::scatter(dim, self.primitive, indexes, values.primitive))
|
||||
Self::new(K::scatter(dim, self.primitive, indices, values.primitive))
|
||||
}
|
||||
|
||||
/// Select the tensor elements along the given dimension corresponding to the given indexes.
|
||||
/// Select the tensor elements along the given dimension corresponding to the given indices.
|
||||
///
|
||||
/// Example using a 3D tensor:
|
||||
///
|
||||
/// `output[i, j, k] = input[indexes[i], j, k]; // dim = 0`
|
||||
/// `output[i, j, k] = input[i, indexes[j], k]; // dim = 1`
|
||||
/// `output[i, j, k] = input[i, j, indexes[k]]; // dim = 2`
|
||||
pub fn index_select(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
|
||||
check!(TensorCheck::index_select::<D>(dim));
|
||||
Self::new(K::index_select(self.primitive, dim, indexes))
|
||||
/// `output[i, j, k] = input[indices[i], j, k]; // dim = 0`
|
||||
/// `output[i, j, k] = input[i, indices[j], k]; // dim = 1`
|
||||
/// `output[i, j, k] = input[i, j, indices[k]]; // dim = 2`
|
||||
pub fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
|
||||
check!(TensorCheck::select::<D>(dim));
|
||||
Self::new(K::select(self.primitive, dim, indices))
|
||||
}
|
||||
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indexes
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indices
|
||||
/// from the value tensor to the original tensor using sum reduction.
|
||||
///
|
||||
/// Example using a 3D tensor:
|
||||
///
|
||||
/// `input[indexes[i], j, k] += values[i, j, k]; // dim = 0`
|
||||
/// `input[i, indexes[j], k] += values[i, j, k]; // dim = 1`
|
||||
/// `input[i, j, indexes[k]] += values[i, j, k]; // dim = 2`
|
||||
pub fn index_select_assign<const D2: usize>(
|
||||
/// `input[indices[i], j, k] += values[i, j, k]; // dim = 0`
|
||||
/// `input[i, indices[j], k] += values[i, j, k]; // dim = 1`
|
||||
/// `input[i, j, indices[k]] += values[i, j, k]; // dim = 2`
|
||||
pub fn select_assign(
|
||||
self,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
values: Tensor<B, D2, K>,
|
||||
indices: Tensor<B, 1, Int>,
|
||||
values: Tensor<B, D, K>,
|
||||
) -> Self {
|
||||
check!(TensorCheck::index_select_assign::<D>(dim));
|
||||
check!(TensorCheck::select_assign::<D>(dim));
|
||||
|
||||
Self::new(K::index_select_assign(
|
||||
Self::new(K::select_assign(
|
||||
self.primitive,
|
||||
dim,
|
||||
indexes,
|
||||
indices,
|
||||
values.primitive,
|
||||
))
|
||||
}
|
||||
|
@ -331,11 +331,11 @@ where
|
|||
|
||||
/// Find the maximum value along the given dimension.
|
||||
///
|
||||
/// Also returns the indexes.
|
||||
pub fn max_dim_with_indexes(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
|
||||
/// Also returns the indices.
|
||||
pub fn max_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
|
||||
check!(TensorCheck::aggregate_dim::<D>("Max", dim));
|
||||
|
||||
let (tensor, index) = K::max_dim_with_indexes(self.primitive, dim);
|
||||
let (tensor, index) = K::max_dim_with_indices(self.primitive, dim);
|
||||
|
||||
let tensor = Tensor::new(tensor);
|
||||
let index = Tensor::new(index);
|
||||
|
@ -375,11 +375,11 @@ where
|
|||
|
||||
/// Find the minimum value along the given dimension.
|
||||
///
|
||||
/// Also returns the indexes.
|
||||
pub fn min_dim_with_indexes(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
|
||||
/// Also returns the indices.
|
||||
pub fn min_dim_with_indices(self, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
|
||||
check!(TensorCheck::aggregate_dim::<D>("Min", dim));
|
||||
|
||||
let (tensor, index) = K::min_dim_with_indexes(self.primitive, dim);
|
||||
let (tensor, index) = K::min_dim_with_indices(self.primitive, dim);
|
||||
|
||||
let tensor = Tensor::new(tensor);
|
||||
let index = Tensor::new(index);
|
||||
|
@ -1008,7 +1008,7 @@ where
|
|||
///
|
||||
/// * `dim` - The axis along which to gather elements.
|
||||
/// * `tensor` - The tensor to gather elements from.
|
||||
/// * `indexes` - The indexes of the elements to gather.
|
||||
/// * `indices` - The indices of the elements to gather.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -1026,7 +1026,7 @@ where
|
|||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
indices: Tensor<B, D, Int>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Scatters elements into a tensor along an axis.
|
||||
|
@ -1035,14 +1035,14 @@ where
|
|||
///
|
||||
/// * `dim` - The axis along which to scatter elements.
|
||||
/// * `tensor` - The tensor to scatter elements into.
|
||||
/// * `indices` - The indexes of the elements to scatter.
|
||||
/// * `indices` - The indices of the elements to scatter.
|
||||
/// * `values` - The values to scatter into the tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is taken from the
|
||||
/// corresponding element of the input tensor at the corresponding index along the specified axis,
|
||||
/// except for the elements at the specified indexes, which are taken from the corresponding
|
||||
/// except for the elements at the specified indices, which are taken from the corresponding
|
||||
/// element of the values tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
|
@ -1056,17 +1056,17 @@ where
|
|||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
indices: Tensor<B, D, Int>,
|
||||
values: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Select tensor elements along the given dimension corresponding for the given indexes.
|
||||
/// Select tensor elements along the given dimension corresponding for the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select elements from.
|
||||
/// * `dim` - The axis along which to select elements.
|
||||
/// * `indexes` - The indexes of the elements to select.
|
||||
/// * `indices` - The indices of the elements to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -1080,28 +1080,28 @@ where
|
|||
/// or use this function directly.
|
||||
///
|
||||
/// For selecting elements from a tensor along an axis, users should prefer the
|
||||
/// [Tensor::index_select](Tensor::index_select) function, which is more high-level and designed for public use.
|
||||
fn index_select<const D: usize>(
|
||||
/// [Tensor::select](Tensor::select) function, which is more high-level and designed for public use.
|
||||
fn select<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
indices: Tensor<B, 1, Int>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indexes
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indices
|
||||
/// from the value tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to assign elements to.
|
||||
/// * `dim` - The axis along which to assign elements.
|
||||
/// * `indexes` - The indexes of the elements to assign.
|
||||
/// * `indices` - The indices of the elements to assign.
|
||||
/// * `values` - The values to assign to the tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is taken from the
|
||||
/// corresponding element of the input tensor at the corresponding index along the specified axis,
|
||||
/// except for the elements at the specified indexes, which are taken from the corresponding
|
||||
/// except for the elements at the specified indices, which are taken from the corresponding
|
||||
/// element of the values tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
|
@ -1111,20 +1111,20 @@ where
|
|||
/// or use this function directly.
|
||||
///
|
||||
/// For assigning elements to a tensor along an axis, users should prefer the
|
||||
/// [Tensor::index_select_assign](Tensor::index_select_assign) function, which is more high-level and designed for public use.
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
/// [Tensor::select_assign](Tensor::select_assign) function, which is more high-level and designed for public use.
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
values: Self::Primitive<D2>,
|
||||
) -> Self::Primitive<D1>;
|
||||
indices: Tensor<B, 1, Int>,
|
||||
values: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Gets the indexes of the maximum elements of a tensor along an axis.
|
||||
/// Gets the indices of the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The axis along which to get the indexes of the maximum elements.
|
||||
/// * `tensor` - The tensor to get the indexes of the maximum elements from.
|
||||
/// * `dim` - The axis along which to get the indices of the maximum elements.
|
||||
/// * `tensor` - The tensor to get the indices of the maximum elements from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -1137,16 +1137,16 @@ where
|
|||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the indexes of the maximum elements of a tensor along an axis, users should prefer the
|
||||
/// For getting the indices of the maximum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
|
||||
fn argmax<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the indexes of the minimum elements of a tensor along an axis.
|
||||
/// Gets the indices of the minimum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The axis along which to get the indexes of the minimum elements.
|
||||
/// * `tensor` - The tensor to get the indexes of the minimum elements from.
|
||||
/// * `dim` - The axis along which to get the indices of the minimum elements.
|
||||
/// * `tensor` - The tensor to get the indices of the minimum elements from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -1159,7 +1159,7 @@ where
|
|||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the indexes of the minimum elements of a tensor along an axis, users should prefer the
|
||||
/// For getting the indices of the minimum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
|
||||
fn argmin<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
|
@ -1224,8 +1224,8 @@ where
|
|||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the maximum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::max_dim_with_indexes](Tensor::max_dim_with_indexes) function, which is more high-level and designed for public use.
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
/// [Tensor::max_dim_with_indices](Tensor::max_dim_with_indices) function, which is more high-level and designed for public use.
|
||||
fn max_dim_with_indices<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
) -> (Self::Primitive<D>, B::IntTensorPrimitive<D>);
|
||||
|
@ -1291,8 +1291,8 @@ where
|
|||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the minimum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::min_dim_with_indexes](Tensor::min_dim_with_indexes) function, which is more high-level and designed for public use.
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
/// [Tensor::min_dim_with_indices](Tensor::min_dim_with_indices) function, which is more high-level and designed for public use.
|
||||
fn min_dim_with_indices<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
) -> (Self::Primitive<D>, B::IntTensorPrimitive<D>);
|
||||
|
@ -1441,37 +1441,37 @@ impl<B: Backend> Numeric<B> for Int {
|
|||
B::int_mask_fill(tensor, mask.primitive, value)
|
||||
}
|
||||
|
||||
fn index_select<const D: usize>(
|
||||
fn select<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
indices: Tensor<B, 1, Int>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::int_index_select_dim(tensor, dim, indexes.primitive)
|
||||
B::int_select(tensor, dim, indices.primitive)
|
||||
}
|
||||
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
values: Self::Primitive<D2>,
|
||||
) -> Self::Primitive<D1> {
|
||||
B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
|
||||
indices: Tensor<B, 1, Int>,
|
||||
values: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::int_select_assign(tensor, dim, indices.primitive, values)
|
||||
}
|
||||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
indices: Tensor<B, D, Int>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::int_gather(dim, tensor, indexes.primitive)
|
||||
B::int_gather(dim, tensor, indices.primitive)
|
||||
}
|
||||
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
indices: Tensor<B, D, Int>,
|
||||
values: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::int_scatter(dim, tensor, indexes.primitive, values)
|
||||
B::int_scatter(dim, tensor, indices.primitive, values)
|
||||
}
|
||||
|
||||
fn argmax<const D: usize>(
|
||||
|
@ -1496,11 +1496,11 @@ impl<B: Backend> Numeric<B> for Int {
|
|||
B::int_max_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
fn max_dim_with_indices<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
|
||||
B::int_max_dim_with_indexes(tensor, dim)
|
||||
B::int_max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
|
||||
|
@ -1511,11 +1511,11 @@ impl<B: Backend> Numeric<B> for Int {
|
|||
B::int_min_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
fn min_dim_with_indices<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
|
||||
B::int_min_dim_with_indexes(tensor, dim)
|
||||
B::int_min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1662,38 +1662,38 @@ impl<B: Backend> Numeric<B> for Float {
|
|||
B::mask_fill(tensor, mask.primitive, value)
|
||||
}
|
||||
|
||||
fn index_select<const D: usize>(
|
||||
fn select<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
indices: Tensor<B, 1, Int>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::index_select(tensor, dim, indexes.primitive)
|
||||
B::select(tensor, dim, indices.primitive)
|
||||
}
|
||||
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
values: Self::Primitive<D2>,
|
||||
) -> Self::Primitive<D1> {
|
||||
B::index_select_assign(tensor, dim, indexes.primitive, values)
|
||||
indices: Tensor<B, 1, Int>,
|
||||
values: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::select_assign(tensor, dim, indices.primitive, values)
|
||||
}
|
||||
|
||||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
indices: Tensor<B, D, Int>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::gather(dim, tensor, indexes.primitive)
|
||||
B::gather(dim, tensor, indices.primitive)
|
||||
}
|
||||
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
indices: Tensor<B, D, Int>,
|
||||
values: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D> {
|
||||
B::scatter(dim, tensor, indexes.primitive, values)
|
||||
B::scatter(dim, tensor, indices.primitive, values)
|
||||
}
|
||||
|
||||
fn argmax<const D: usize>(
|
||||
|
@ -1718,11 +1718,11 @@ impl<B: Backend> Numeric<B> for Float {
|
|||
B::max_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
fn max_dim_with_indices<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
|
||||
B::max_dim_with_indexes(tensor, dim)
|
||||
B::max_dim_with_indices(tensor, dim)
|
||||
}
|
||||
|
||||
fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1> {
|
||||
|
@ -1733,11 +1733,11 @@ impl<B: Backend> Numeric<B> for Float {
|
|||
B::min_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
fn min_dim_with_indices<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
) -> (Self::Primitive<D>, <B as Backend>::IntTensorPrimitive<D>) {
|
||||
B::min_dim_with_indexes(tensor, dim)
|
||||
B::min_dim_with_indices(tensor, dim)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5,11 +5,11 @@ use crate::{
|
|||
};
|
||||
|
||||
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
|
||||
pub fn embedding<B>(weights: Tensor<B, 2>, indexes: Tensor<B, 2, Int>) -> Tensor<B, 3>
|
||||
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(B::embedding(weights.primitive, indexes.primitive))
|
||||
Tensor::new(B::embedding(weights.primitive, indices.primitive))
|
||||
}
|
||||
|
||||
/// Applies a [1D convolution](crate::ops::ModuleOps::conv2d).
|
||||
|
@ -123,8 +123,8 @@ where
|
|||
Tensor::new(B::avg_pool1d(x.primitive, kernel_size, stride, padding))
|
||||
}
|
||||
|
||||
/// Applies a [2D max pooling with indexes](crate::ops::ModuleOps::max_pool2d_with_indexes).
|
||||
pub fn max_pool2d_with_indexes<B>(
|
||||
/// Applies a [2D max pooling with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
|
||||
pub fn max_pool2d_with_indices<B>(
|
||||
x: Tensor<B, 4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
|
@ -133,7 +133,7 @@ pub fn max_pool2d_with_indexes<B>(
|
|||
where
|
||||
B: Backend,
|
||||
{
|
||||
let output = B::max_pool2d_with_indexes(x.primitive, kernel_size, stride, padding);
|
||||
let output = B::max_pool2d_with_indices(x.primitive, kernel_size, stride, padding);
|
||||
|
||||
(Tensor::new(output.output), Tensor::new(output.indexes))
|
||||
(Tensor::new(output.output), Tensor::new(output.indices))
|
||||
}
|
||||
|
|
|
@ -114,35 +114,35 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
shape: Shape<D2>,
|
||||
) -> B::BoolTensorPrimitive<D2>;
|
||||
|
||||
/// Gets the values from the tensor for the given indexes.
|
||||
/// Gets the values from the tensor for the given ranges.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes to get the values from.
|
||||
/// * `ranges` - The ranges to get the values from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the values for the given indexes.
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
/// The tensor with the values for the given ranges.
|
||||
fn bool_slice<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> B::BoolTensorPrimitive<D1>;
|
||||
|
||||
/// Sets the values in the tensor for the given indexes.
|
||||
/// Sets the values in the tensor for the given ranges.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes to set the values for.
|
||||
/// * `ranges` - The ranges to set the values for.
|
||||
/// * `value` - The values to set.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the values set for the given indexes.
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
/// The tensor with the values set for the given ranges.
|
||||
fn bool_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: B::BoolTensorPrimitive<D1>,
|
||||
) -> B::BoolTensorPrimitive<D1>;
|
||||
|
||||
|
@ -169,7 +169,7 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let indexes_select_all = [0; D].map(|_| {
|
||||
let ranges_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
|
@ -178,9 +178,9 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
|
||||
let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor));
|
||||
for i in 0..times {
|
||||
let mut indexes = indexes_select_all.clone();
|
||||
indexes[dim] = i..i + 1;
|
||||
tensor_output = Self::bool_index_assign(tensor_output, indexes, tensor.clone());
|
||||
let mut ranges = ranges_select_all.clone();
|
||||
ranges[dim] = i..i + 1;
|
||||
tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
|
|
|
@ -110,9 +110,9 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The elements at the given indices.
|
||||
fn int_index<const D1: usize, const D2: usize>(
|
||||
fn int_slice<const D1: usize, const D2: usize>(
|
||||
tensor: B::IntTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
indices: [Range<usize>; D2],
|
||||
) -> B::IntTensorPrimitive<D1>;
|
||||
|
||||
/// Sets the element at the given indices.
|
||||
|
@ -125,7 +125,7 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The tensor with the element at the given indices set.
|
||||
fn int_index_assign<const D1: usize, const D2: usize>(
|
||||
fn int_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::IntTensorPrimitive<D1>,
|
||||
indices: [Range<usize>; D2],
|
||||
value: B::IntTensorPrimitive<D1>,
|
||||
|
@ -204,15 +204,15 @@ pub trait IntTensorOps<B: Backend> {
|
|||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes.
|
||||
/// * `indices` - The indices.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements.
|
||||
fn int_index_select_dim<const D: usize>(
|
||||
fn int_select<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
indices: B::IntTensorPrimitive<1>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indices
|
||||
|
@ -222,18 +222,18 @@ pub trait IntTensorOps<B: Backend> {
|
|||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes.
|
||||
/// * `indices` - The indices.
|
||||
/// * `value` - The value.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::IntTensorPrimitive<D1>,
|
||||
fn int_select_assign<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
value: B::IntTensorPrimitive<D2>,
|
||||
) -> B::IntTensorPrimitive<D1>;
|
||||
indices: B::IntTensorPrimitive<1>,
|
||||
value: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Repeats the tensor along the given dimension the given number of times.
|
||||
///
|
||||
|
@ -258,7 +258,7 @@ pub trait IntTensorOps<B: Backend> {
|
|||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let indexes_select_all = [0; D].map(|_| {
|
||||
let indices_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
|
@ -267,9 +267,9 @@ pub trait IntTensorOps<B: Backend> {
|
|||
|
||||
let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor));
|
||||
for i in 0..times {
|
||||
let mut indexes = indexes_select_all.clone();
|
||||
indexes[dim] = i..i + 1;
|
||||
tensor_output = Self::int_index_assign(tensor_output, indexes, tensor.clone());
|
||||
let mut indices = indices_select_all.clone();
|
||||
indices[dim] = i..i + 1;
|
||||
tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
|
@ -725,7 +725,7 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The maximum elements and corresponding indices along the dimension.
|
||||
fn int_max_dim_with_indexes<const D: usize>(
|
||||
fn int_max_dim_with_indices<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
|
||||
|
@ -780,13 +780,13 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The minimum elements and corresponding indices along the dimension.
|
||||
fn int_min_dim_with_indexes<const D: usize>(
|
||||
fn int_min_dim_with_indices<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
|
||||
let index = B::int_argmin(tensor.clone(), dim);
|
||||
let values = B::int_gather(D - 1, tensor, index.clone());
|
||||
let indices = B::int_argmin(tensor.clone(), dim);
|
||||
let values = B::int_gather(D - 1, tensor, indices.clone());
|
||||
|
||||
(values, index)
|
||||
(values, indices)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,14 +21,14 @@ pub struct MaxPool2dBackward<B: Backend> {
|
|||
pub x_grad: B::TensorPrimitive<4>,
|
||||
}
|
||||
|
||||
/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indexes).
|
||||
/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indices).
|
||||
#[derive(new)]
|
||||
pub struct MaxPool2dWithIndexes<B: Backend> {
|
||||
pub struct MaxPool2dWithIndices<B: Backend> {
|
||||
/// The output tensor.
|
||||
pub output: B::TensorPrimitive<4>,
|
||||
|
||||
/// The indexes tensor.
|
||||
pub indexes: B::IntTensorPrimitive<4>,
|
||||
/// The indices tensor.
|
||||
pub indices: B::IntTensorPrimitive<4>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
|
||||
|
@ -86,20 +86,20 @@ pub trait ModuleOps<B: Backend> {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `weights` - The embedding weights.
|
||||
/// * `indexes` - The indexes tensor.
|
||||
/// * `indices` - The indices tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output tensor.
|
||||
fn embedding(
|
||||
weights: B::TensorPrimitive<2>,
|
||||
indexes: B::IntTensorPrimitive<2>,
|
||||
indices: B::IntTensorPrimitive<2>,
|
||||
) -> B::TensorPrimitive<3> {
|
||||
let [batch_size, seq_length] = B::int_shape(&indexes).dims;
|
||||
let [batch_size, seq_length] = B::int_shape(&indices).dims;
|
||||
let [_, d_model] = B::shape(&weights).dims;
|
||||
|
||||
let indexes = B::int_reshape(indexes, Shape::new([batch_size * seq_length]));
|
||||
let output = B::index_select(weights, 0, indexes);
|
||||
let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
|
||||
let output = B::select(weights, 0, indices);
|
||||
|
||||
B::reshape(output, Shape::new([batch_size, seq_length, d_model]))
|
||||
}
|
||||
|
@ -110,7 +110,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
///
|
||||
/// * `weights` - The embedding weights.
|
||||
/// * `output_grad` - The output gradient.
|
||||
/// * `indexes` - The indexes tensor.
|
||||
/// * `indices` - The indices tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -118,17 +118,17 @@ pub trait ModuleOps<B: Backend> {
|
|||
fn embedding_backward(
|
||||
weights: B::TensorPrimitive<2>,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
indexes: B::IntTensorPrimitive<2>,
|
||||
indices: B::IntTensorPrimitive<2>,
|
||||
) -> B::TensorPrimitive<2> {
|
||||
let [batch_size, seq_length] = B::int_shape(&indexes).dims;
|
||||
let [batch_size, seq_length] = B::int_shape(&indices).dims;
|
||||
let [n_embeddings, d_model] = B::shape(&weights).dims;
|
||||
let device = B::device(&weights);
|
||||
|
||||
let indexes = B::int_reshape(indexes, Shape::new([batch_size * seq_length]));
|
||||
let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
|
||||
let output_grad = B::reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
|
||||
let grad = B::zeros(Shape::new([n_embeddings, d_model]), &device);
|
||||
|
||||
B::index_select_assign(grad, 0, indexes, output_grad)
|
||||
B::select_assign(grad, 0, indices, output_grad)
|
||||
}
|
||||
|
||||
/// Two dimensional convolution.
|
||||
|
@ -263,24 +263,24 @@ pub trait ModuleOps<B: Backend> {
|
|||
padding: [usize; 2],
|
||||
) -> B::TensorPrimitive<4>;
|
||||
|
||||
/// Two dimensional max pooling with indexes.
|
||||
/// Two dimensional max pooling with indices.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// x: [batch_size, channels, height, width],
|
||||
fn max_pool2d_with_indexes(
|
||||
fn max_pool2d_with_indices(
|
||||
x: B::TensorPrimitive<4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
) -> MaxPool2dWithIndexes<B>;
|
||||
/// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indexes) operation.
|
||||
fn max_pool2d_with_indexes_backward(
|
||||
) -> MaxPool2dWithIndices<B>;
|
||||
/// Backward pass for the [max pooling 2d](ModuleOps::max_pool2d_with_indices) operation.
|
||||
fn max_pool2d_with_indices_backward(
|
||||
x: B::TensorPrimitive<4>,
|
||||
kernel_size: [usize; 2],
|
||||
stride: [usize; 2],
|
||||
padding: [usize; 2],
|
||||
output_grad: B::TensorPrimitive<4>,
|
||||
indexes: B::IntTensorPrimitive<4>,
|
||||
indices: B::IntTensorPrimitive<4>,
|
||||
) -> MaxPool2dBackward<B>;
|
||||
}
|
||||
|
|
|
@ -243,8 +243,8 @@ fn conv1d_weight_grad_groups<B: Backend>(
|
|||
let start_idx_co = g * increment_co;
|
||||
let end_idx_co = (g + 1) * increment_co;
|
||||
|
||||
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
|
||||
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
|
||||
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
|
||||
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
|
||||
let mut weight_grad_tmp = B::conv1d(
|
||||
x,
|
||||
grad,
|
||||
|
@ -252,7 +252,7 @@ fn conv1d_weight_grad_groups<B: Backend>(
|
|||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
);
|
||||
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
|
||||
weight_grad = B::index_assign(
|
||||
weight_grad = B::slice_assign(
|
||||
weight_grad,
|
||||
[start_idx_co..end_idx_co, 0..increment_ci, 0..kernel_size],
|
||||
weight_grad_tmp,
|
||||
|
@ -280,8 +280,8 @@ fn conv2d_weight_grad_groups<B: Backend>(
|
|||
let start_idx_co = g * increment_co;
|
||||
let end_idx_co = (g + 1) * increment_co;
|
||||
|
||||
let x = B::index(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
|
||||
let grad = B::index(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
|
||||
let x = B::slice(x_swapped.clone(), [start_idx_ci..end_idx_ci]);
|
||||
let grad = B::slice(output_grad_swapped.clone(), [start_idx_co..end_idx_co]);
|
||||
let mut weight_grad_tmp = B::conv2d(
|
||||
x,
|
||||
grad,
|
||||
|
@ -289,7 +289,7 @@ fn conv2d_weight_grad_groups<B: Backend>(
|
|||
ConvOptions::new(options.dilation, options.padding, options.stride, 1),
|
||||
);
|
||||
weight_grad_tmp = B::swap_dims(weight_grad_tmp, 0, 1);
|
||||
weight_grad = B::index_assign(
|
||||
weight_grad = B::slice_assign(
|
||||
weight_grad,
|
||||
[
|
||||
start_idx_co..end_idx_co,
|
||||
|
@ -321,7 +321,7 @@ fn conv1d_weight_grad_no_groups<B: Backend>(
|
|||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
if B::shape(&weight_grad) != weight_shape {
|
||||
weight_grad = B::index(
|
||||
weight_grad = B::slice(
|
||||
weight_grad,
|
||||
[
|
||||
0..weight_shape.dims[0],
|
||||
|
@ -350,7 +350,7 @@ fn conv2d_weight_grad_no_groups<B: Backend>(
|
|||
let mut weight_grad = B::swap_dims(weight_grad_swapped, 0, 1);
|
||||
|
||||
if B::shape(&weight_grad) != weight_shape {
|
||||
weight_grad = B::index(
|
||||
weight_grad = B::slice(
|
||||
weight_grad,
|
||||
[
|
||||
0..weight_shape.dims[0],
|
||||
|
|
|
@ -181,7 +181,7 @@ pub trait TensorOps<B: Backend> {
|
|||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let indexes_select_all = [0; D].map(|_| {
|
||||
let indices_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
|
@ -190,9 +190,9 @@ pub trait TensorOps<B: Backend> {
|
|||
|
||||
let mut tensor_output = B::empty(shape, &B::device(&tensor));
|
||||
for i in 0..times {
|
||||
let mut indexes = indexes_select_all.clone();
|
||||
indexes[dim] = i..i + 1;
|
||||
tensor_output = B::index_assign(tensor_output, indexes, tensor.clone());
|
||||
let mut indices = indices_select_all.clone();
|
||||
indices[dim] = i..i + 1;
|
||||
tensor_output = B::slice_assign(tensor_output, indices, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
|
@ -380,7 +380,7 @@ pub trait TensorOps<B: Backend> {
|
|||
///
|
||||
/// * `dim` - The dimension to gather from.
|
||||
/// * `tensor` - The tensor to gather from.
|
||||
/// * `indexes` - The indexes to gather.
|
||||
/// * `indices` - The indices to gather.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
|
@ -388,7 +388,7 @@ pub trait TensorOps<B: Backend> {
|
|||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
indexes: B::IntTensorPrimitive<D>,
|
||||
indices: B::IntTensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Scatter elements into a tensor.
|
||||
|
@ -397,7 +397,7 @@ pub trait TensorOps<B: Backend> {
|
|||
///
|
||||
/// * `dim` - The dimension to scatter into.
|
||||
/// * `tensor` - The tensor to scatter into.
|
||||
/// * `indexes` - The indexes to scatter into.
|
||||
/// * `indices` - The indices to scatter into.
|
||||
/// * `value` - The value to scatter.
|
||||
///
|
||||
/// # Returns
|
||||
|
@ -406,76 +406,76 @@ pub trait TensorOps<B: Backend> {
|
|||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
indexes: B::IntTensorPrimitive<D>,
|
||||
indices: B::IntTensorPrimitive<D>,
|
||||
value: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Select tensor elements along the given dimension corresponding for the given indexes.
|
||||
/// Select tensor elements along the given dimension corresponding to the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
/// * `indices` - The indices to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The selected elements.
|
||||
fn index_select<const D: usize>(
|
||||
fn select<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
indices: B::IntTensorPrimitive<1>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Assign the selected elements along the given dimension corresponding for the given indexes
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indices
|
||||
/// to the given value.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
/// * `indices` - The indices to select.
|
||||
/// * `value` - The value to assign.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::TensorPrimitive<D1>,
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
value: B::TensorPrimitive<D2>,
|
||||
) -> B::TensorPrimitive<D1>;
|
||||
indices: B::IntTensorPrimitive<1>,
|
||||
value: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Select tensor elements corresponding for the given indexes.
|
||||
/// Select tensor elements corresponding to the given ranges.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
/// * `ranges` - The ranges to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The selected elements in a new tensor.
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: B::TensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> B::TensorPrimitive<D1>;
|
||||
|
||||
/// Assign the selected elements corresponding for the given indexes to the given value.
|
||||
/// Assign the selected elements corresponding to the given ranges to the given value.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
/// * `ranges` - The ranges to select.
|
||||
/// * `value` - The value to assign.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::TensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: B::TensorPrimitive<D1>,
|
||||
) -> B::TensorPrimitive<D1>;
|
||||
|
||||
|
@ -872,7 +872,7 @@ pub trait TensorOps<B: Backend> {
|
|||
dim: usize,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Gets the indexes of the maximum elements of a tensor along an axis.
|
||||
/// Gets the indices of the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
|
@ -881,13 +881,13 @@ pub trait TensorOps<B: Backend> {
|
|||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the indexes of the maximum elements of `tensor` along `dim`.
|
||||
/// A tensor with the indices of the maximum elements of `tensor` along `dim`.
|
||||
fn argmax<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the indexes of the minimum elements of a tensor along an axis.
|
||||
/// Gets the indices of the minimum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
|
@ -896,7 +896,7 @@ pub trait TensorOps<B: Backend> {
|
|||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the indexes of the minimum elements of `tensor` along `dim`.
|
||||
/// A tensor with the indices of the minimum elements of `tensor` along `dim`.
|
||||
fn argmin<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -934,7 +934,7 @@ pub trait TensorOps<B: Backend> {
|
|||
B::gather(D - 1, tensor, index)
|
||||
}
|
||||
|
||||
/// Gets the maximum elements of a tensor along an axis and their indexes.
|
||||
/// Gets the maximum elements of a tensor along an axis and their indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
|
@ -943,8 +943,8 @@ pub trait TensorOps<B: Backend> {
|
|||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tuple with the maximum elements of `tensor` along `dim` and their indexes.
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
/// A tuple with the maximum elements of `tensor` along `dim` and their indices.
|
||||
fn max_dim_with_indices<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
|
||||
|
@ -986,7 +986,7 @@ pub trait TensorOps<B: Backend> {
|
|||
B::gather(D - 1, tensor, index)
|
||||
}
|
||||
|
||||
/// Gets the minimum elements of a tensor along an axis and their indexes.
|
||||
/// Gets the minimum elements of a tensor along an axis and their indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
|
@ -995,8 +995,8 @@ pub trait TensorOps<B: Backend> {
|
|||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tuple with the minimum elements of `tensor` along `dim` and their indexes.
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
/// A tuple with the minimum elements of `tensor` along `dim` and their indices.
|
||||
fn min_dim_with_indices<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
|
||||
|
|
|
@ -37,9 +37,9 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_log!();
|
||||
burn_tensor::testgen_sqrt!();
|
||||
burn_tensor::testgen_log1p!();
|
||||
burn_tensor::testgen_index!();
|
||||
burn_tensor::testgen_slice!();
|
||||
burn_tensor::testgen_gather_scatter!();
|
||||
burn_tensor::testgen_index_select!();
|
||||
burn_tensor::testgen_select!();
|
||||
burn_tensor::testgen_map_comparison!();
|
||||
burn_tensor::testgen_mask!();
|
||||
burn_tensor::testgen_matmul!();
|
||||
|
|
|
@ -6,11 +6,11 @@ mod tests {
|
|||
#[test]
|
||||
fn test_embedding_forward() {
|
||||
let weights = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = Data::from([[0, 1], [1, 1]]);
|
||||
let indices = Data::from([[0, 1], [1, 1]]);
|
||||
let weights = Tensor::<TestBackend, 2>::from_data(weights);
|
||||
let indexes = Tensor::<TestBackend, 2, Int>::from_data(indexes);
|
||||
let indices = Tensor::<TestBackend, 2, Int>::from_data(indices);
|
||||
|
||||
let output = embedding(weights, indexes);
|
||||
let output = embedding(weights, indices);
|
||||
let expected = Data::from([
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(module_max_pool2d)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::module::{max_pool2d, max_pool2d_with_indexes};
|
||||
use burn_tensor::module::{max_pool2d, max_pool2d_with_indices};
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
#[test]
|
||||
|
@ -180,7 +180,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_pool2d_with_indexes() {
|
||||
fn test_max_pool2d_with_indices() {
|
||||
let batch_size = 1;
|
||||
let channels_in = 1;
|
||||
let kernel_size_1 = 2;
|
||||
|
@ -196,7 +196,7 @@ mod tests {
|
|||
[0.5416, 0.8602, 0.8129, 0.1662],
|
||||
[0.3358, 0.3059, 0.8293, 0.0990],
|
||||
]]]);
|
||||
let indexes = Data::<i64, 4>::from([[[
|
||||
let indices = Data::<i64, 4>::from([[[
|
||||
[0, 1, 1, 3, 3],
|
||||
[4, 4, 1, 7, 7],
|
||||
[4, 9, 9, 7, 7],
|
||||
|
@ -211,7 +211,7 @@ mod tests {
|
|||
[0.3358, 0.3358, 0.8293, 0.8293, 0.0990],
|
||||
]]]);
|
||||
|
||||
let (output, output_indexes) = max_pool2d_with_indexes(
|
||||
let (output, output_indices) = max_pool2d_with_indices(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
|
@ -219,7 +219,7 @@ mod tests {
|
|||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
assert_eq!(indexes.value, output_indexes.into_data().value);
|
||||
assert_eq!(indices.value, output_indices.into_data().value);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -240,7 +240,7 @@ mod tests {
|
|||
[0.4384, 0.9963, 0.9698, 0.4988, 0.2609],
|
||||
[0.3391, 0.2230, 0.4610, 0.5365, 0.6880],
|
||||
]]]);
|
||||
let indexes = Data::<i64, 4>::from([[[
|
||||
let indices = Data::<i64, 4>::from([[[
|
||||
[5, 7, 3],
|
||||
[5, 7, 3],
|
||||
[5, 16, 3],
|
||||
|
@ -256,7 +256,7 @@ mod tests {
|
|||
[0.4384, 0.9963, 0.688],
|
||||
[0.4384, 0.9963, 0.688],
|
||||
]]]);
|
||||
let (output, output_indexes) = max_pool2d_with_indexes(
|
||||
let (output, output_indices) = max_pool2d_with_indices(
|
||||
x,
|
||||
[kernel_size_1, kernel_size_2],
|
||||
[stride_1, stride_2],
|
||||
|
@ -264,6 +264,6 @@ mod tests {
|
|||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
assert_eq!(indexes.value, output_indexes.into_data().value);
|
||||
assert_eq!(indices.value, output_indices.into_data().value);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,9 +6,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_gather_1d_dim0() {
|
||||
let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]);
|
||||
let indexes = TestTensorInt::from_ints([1, 1, 0, 1, 2]);
|
||||
let indices = TestTensorInt::from_ints([1, 1, 0, 1, 2]);
|
||||
|
||||
let output = tensor.gather(0, indexes);
|
||||
let output = tensor.gather(0, indices);
|
||||
|
||||
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
|
||||
}
|
||||
|
@ -16,9 +16,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_gather_2d_dim0() {
|
||||
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]);
|
||||
let indices = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]);
|
||||
|
||||
let output = tensor.gather(0, indexes);
|
||||
let output = tensor.gather(0, indices);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -29,9 +29,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_gather_2d_dim1() {
|
||||
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]);
|
||||
let indices = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]);
|
||||
|
||||
let output = tensor.gather(1, indexes);
|
||||
let output = tensor.gather(1, indices);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -45,9 +45,9 @@ mod tests {
|
|||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
]);
|
||||
let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
|
||||
let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
|
||||
|
||||
let output = tensor.gather(1, indexes);
|
||||
let output = tensor.gather(1, indices);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -61,9 +61,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_gather_2d_only_1dim() {
|
||||
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]);
|
||||
let indices = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]);
|
||||
|
||||
let output = tensor.gather(1, indexes);
|
||||
let output = tensor.gather(1, indices);
|
||||
|
||||
assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
|
||||
}
|
||||
|
@ -72,9 +72,9 @@ mod tests {
|
|||
fn should_scatter_1d() {
|
||||
let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]);
|
||||
let values = TestTensor::from_floats([5.0, 4.0, 3.0]);
|
||||
let indexes = TestTensorInt::from_ints([1, 0, 2]);
|
||||
let indices = TestTensorInt::from_ints([1, 0, 2]);
|
||||
|
||||
let output = tensor.scatter(0, indexes, values);
|
||||
let output = tensor.scatter(0, indices, values);
|
||||
|
||||
assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
|
||||
}
|
||||
|
@ -83,9 +83,9 @@ mod tests {
|
|||
fn should_scatter_2d_dim0() {
|
||||
let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
|
||||
let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
let indexes = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]);
|
||||
let indices = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]);
|
||||
|
||||
let output = tensor.scatter(0, indexes, values);
|
||||
let output = tensor.scatter(0, indices, values);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -97,9 +97,9 @@ mod tests {
|
|||
fn should_scatter_2d_dim1() {
|
||||
let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
|
||||
let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
let indexes = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]);
|
||||
let indices = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]);
|
||||
|
||||
let output = tensor.scatter(1, indexes, values);
|
||||
let output = tensor.scatter(1, indices, values);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -117,9 +117,9 @@ mod tests {
|
|||
[[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],
|
||||
[[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
|
||||
]);
|
||||
let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
|
||||
let indices = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
|
||||
|
||||
let output = tensor.scatter(1, indexes, values);
|
||||
let output = tensor.scatter(1, indices, values);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
|
|
@ -14,10 +14,10 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_dim_with_indexes_2d() {
|
||||
fn test_max_dim_with_indices_2d() {
|
||||
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let (output_actual, index_actual) = tensor.max_dim_with_indexes(1);
|
||||
let (output_actual, index_actual) = tensor.max_dim_with_indices(1);
|
||||
|
||||
let output_expected = Data::from([[2.], [5.]]);
|
||||
let index_expected = Data::from([[2], [2]]);
|
||||
|
@ -37,10 +37,10 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_dim_with_indexes_2d() {
|
||||
fn test_min_dim_with_indices_2d() {
|
||||
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
|
||||
let (output_actual, index_actual) = tensor.min_dim_with_indexes(1);
|
||||
let (output_actual, index_actual) = tensor.min_dim_with_indices(1);
|
||||
|
||||
let output_expected = Data::from([[0.], [3.]]);
|
||||
let index_expected = Data::from([[0], [0]]);
|
||||
|
|
|
@ -9,8 +9,6 @@ mod erf;
|
|||
mod exp;
|
||||
mod flatten;
|
||||
mod gather_scatter;
|
||||
mod index;
|
||||
mod index_select;
|
||||
mod log;
|
||||
mod log1p;
|
||||
mod map_comparison;
|
||||
|
@ -22,7 +20,9 @@ mod neg;
|
|||
mod powf;
|
||||
mod repeat;
|
||||
mod reshape;
|
||||
mod select;
|
||||
mod sin;
|
||||
mod slice;
|
||||
mod sqrt;
|
||||
mod squeeze;
|
||||
mod sub;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#[burn_tensor_testgen::testgen(index_select)]
|
||||
#[burn_tensor_testgen::testgen(select)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
@ -6,9 +6,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_select_1d() {
|
||||
let tensor = TestTensor::from_data([0.0, 1.0, 2.0]);
|
||||
let indexes = TestTensorInt::from_data([1, 1, 0, 1, 2]);
|
||||
let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]);
|
||||
|
||||
let output = tensor.index_select(0, indexes);
|
||||
let output = tensor.select(0, indices);
|
||||
|
||||
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
|
||||
}
|
||||
|
@ -16,9 +16,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_select_2d_dim0_same_num_dim() {
|
||||
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = TestTensorInt::from_data(([1, 0]));
|
||||
let indices = TestTensorInt::from_data(([1, 0]));
|
||||
|
||||
let output = tensor.index_select(0, indexes);
|
||||
let output = tensor.select(0, indices);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -29,9 +29,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_select_2d_dim0_more_num_dim() {
|
||||
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = TestTensorInt::from_data([1, 0, 1, 1]);
|
||||
let indices = TestTensorInt::from_data([1, 0, 1, 1]);
|
||||
|
||||
let output = tensor.index_select(0, indexes);
|
||||
let output = tensor.select(0, indices);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -47,9 +47,9 @@ mod tests {
|
|||
#[test]
|
||||
fn should_select_2d_dim1() {
|
||||
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let indexes = TestTensorInt::from_data([1, 1, 0, 1, 2]);
|
||||
let indices = TestTensorInt::from_data([1, 1, 0, 1, 2]);
|
||||
|
||||
let output = tensor.index_select(1, indexes);
|
||||
let output = tensor.select(1, indices);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -61,9 +61,9 @@ mod tests {
|
|||
fn should_select_assign_1d() {
|
||||
let tensor = TestTensor::from_data([0.0, 1.0, 2.0]);
|
||||
let values = TestTensor::from_data([5.0, 4.0, 3.0, 2.0, 1.0]);
|
||||
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
|
||||
let indices = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
|
||||
|
||||
let output = tensor.index_select_assign(0, indexes, values);
|
||||
let output = tensor.select_assign(0, indices, values);
|
||||
|
||||
assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
|
||||
}
|
||||
|
@ -72,9 +72,9 @@ mod tests {
|
|||
fn should_select_assign_2d_dim0() {
|
||||
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
let indexes = TestTensorInt::from_data(Data::from([1, 0]));
|
||||
let indices = TestTensorInt::from_data(Data::from([1, 0]));
|
||||
|
||||
let output = tensor.index_select_assign(0, indexes, values);
|
||||
let output = tensor.select_assign(0, indices, values);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
||||
|
@ -86,9 +86,9 @@ mod tests {
|
|||
fn should_select_assign_2d_dim1() {
|
||||
let tensor = TestTensor::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let values = TestTensor::from_data([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
|
||||
let indexes = TestTensorInt::from_data(Data::from([1, 0, 2]));
|
||||
let indices = TestTensorInt::from_data(Data::from([1, 0, 2]));
|
||||
|
||||
let output = tensor.index_select_assign(1, indexes, values);
|
||||
let output = tensor.select_assign(1, indices, values);
|
||||
|
||||
assert_eq!(
|
||||
output.into_data(),
|
|
@ -1,94 +1,94 @@
|
|||
#[burn_tensor_testgen::testgen(index)]
|
||||
#[burn_tensor_testgen::testgen(slice)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_full_indexing_1d() {
|
||||
fn should_support_full_sliceing_1d() {
|
||||
let data = Data::from([0.0, 1.0, 2.0]);
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data.clone());
|
||||
|
||||
let data_actual = tensor.index([0..3]).into_data();
|
||||
let data_actual = tensor.slice([0..3]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_partial_indexing_1d() {
|
||||
fn should_support_partial_sliceing_1d() {
|
||||
let data = Data::from([0.0, 1.0, 2.0]);
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data);
|
||||
|
||||
let data_actual = tensor.index([1..3]).into_data();
|
||||
let data_actual = tensor.slice([1..3]).into_data();
|
||||
|
||||
let data_expected = Data::from([1.0, 2.0]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_full_indexing_2d() {
|
||||
fn should_support_full_sliceing_2d() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data.clone());
|
||||
|
||||
let data_actual_1 = tensor.clone().index([0..2]).into_data();
|
||||
let data_actual_2 = tensor.index([0..2, 0..3]).into_data();
|
||||
let data_actual_1 = tensor.clone().slice([0..2]).into_data();
|
||||
let data_actual_2 = tensor.slice([0..2, 0..3]).into_data();
|
||||
|
||||
assert_eq!(data, data_actual_1);
|
||||
assert_eq!(data, data_actual_2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_partial_indexing_2d() {
|
||||
fn should_support_partial_sliceing_2d() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
let data_actual = tensor.index([0..2, 0..2]).into_data();
|
||||
let data_actual = tensor.slice([0..2, 0..2]).into_data();
|
||||
|
||||
let data_expected = Data::from([[0.0, 1.0], [3.0, 4.0]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_partial_indexing_3d() {
|
||||
fn should_support_partial_sliceing_3d() {
|
||||
let tensor = TestTensor::from_floats([
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
]);
|
||||
|
||||
let data_actual = tensor.index([1..2, 1..2, 0..2]).into_data();
|
||||
let data_actual = tensor.slice([1..2, 1..2, 0..2]).into_data();
|
||||
|
||||
let data_expected = Data::from([[[9.0, 10.0]]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_partial_indexing_3d_non_continuous() {
|
||||
fn should_support_partial_sliceing_3d_non_continuous() {
|
||||
let tensor = TestTensor::from_floats([
|
||||
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
|
||||
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
|
||||
]);
|
||||
|
||||
let data_actual = tensor.transpose().index([1..2, 1..2, 0..2]).into_data();
|
||||
let data_actual = tensor.transpose().slice([1..2, 1..2, 0..2]).into_data();
|
||||
|
||||
let data_expected = Data::from([[[7.0, 10.0]]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_indexe_assign_1d() {
|
||||
fn should_support_slicee_assign_1d() {
|
||||
let data = Data::from([0.0, 1.0, 2.0]);
|
||||
let data_assigned = Data::from([10.0, 5.0]);
|
||||
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data);
|
||||
let tensor_assigned = Tensor::<TestBackend, 1>::from_data(data_assigned);
|
||||
|
||||
let data_actual = tensor.index_assign([0..2], tensor_assigned).into_data();
|
||||
let data_actual = tensor.slice_assign([0..2], tensor_assigned).into_data();
|
||||
|
||||
let data_expected = Data::from([10.0, 5.0, 2.0]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_indexe_assign_2d() {
|
||||
fn should_support_slicee_assign_2d() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let data_assigned = Data::from([[10.0, 5.0]]);
|
||||
|
||||
|
@ -96,7 +96,7 @@ mod tests {
|
|||
let tensor_assigned = Tensor::<TestBackend, 2>::from_data(data_assigned);
|
||||
|
||||
let data_actual = tensor
|
||||
.index_assign([1..2, 0..2], tensor_assigned)
|
||||
.slice_assign([1..2, 0..2], tensor_assigned)
|
||||
.into_data();
|
||||
|
||||
let data_expected = Data::from([[0.0, 1.0, 2.0], [10.0, 5.0, 5.0]]);
|
|
@ -52,4 +52,4 @@ harness = false
|
|||
|
||||
[[bench]]
|
||||
name = "matmul"
|
||||
harness = false
|
||||
harness = false
|
||||
|
|
|
@ -35,6 +35,41 @@ macro_rules! kernel_wgsl {
|
|||
};
|
||||
}
|
||||
|
||||
kernel_wgsl!(ContinuousRaw, "../template/continuous.wgsl");
|
||||
|
||||
pub(crate) fn into_continuous<E: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
if tensor.is_continuous() {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(tensor.shape.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);
|
||||
let info = build_info(&[&tensor, &output]);
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<ContinuousRaw, E, i32, 256, 1, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[&tensor.buffer, &output.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Generates kernel source code by replacing some information using templating.
|
||||
pub struct KernelSettings<
|
||||
K: StaticKernel,
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
use crate::{
|
||||
element::WgpuElement,
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
kernel_wgsl!(Gather, "../../template/index/gather.wgsl");
|
||||
|
||||
pub(crate) fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
|
||||
dim: usize,
|
||||
tensor: WgpuTensor<E, D>,
|
||||
indices: WgpuTensor<I, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let shape_output = indices.shape.clone();
|
||||
let num_elems = shape_output.num_elements();
|
||||
let indices = kernel::into_continuous(indices);
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
|
||||
let mut info = build_info(&[&tensor, &output]);
|
||||
info.push(dim as u32);
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<Gather, E, i32, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
kernel,
|
||||
&[
|
||||
&tensor.buffer,
|
||||
&indices.buffer,
|
||||
&output.buffer,
|
||||
&info_buffer,
|
||||
],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{ReferenceBackend, TestBackend};
    use burn_tensor::{backend::Backend, Distribution, Int, Tensor};

    /// Checks the `gather` kernel against the reference backend on a 6 x 256 input with
    /// random indices along dimension 1.
    #[test]
    fn gather_should_work_with_multiple_workgroups() {
        TestBackend::seed(0);

        let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);

        // Indices are sampled as floats in [0, 256), converted to ints, then reshaped to
        // match the input tensor.
        let sampled =
            Tensor::<TestBackend, 1>::random([6 * 256], Distribution::Uniform(0., 256.));
        let indices = Tensor::<TestBackend, 1, Int>::from_data(sampled.into_data().convert())
            .reshape([6, 256]);

        // Mirror both inputs on the reference backend before they are consumed.
        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
        let indices_ref =
            Tensor::<ReferenceBackend, 2, Int>::from_data(indices.to_data().convert());

        let gathered = gather(1, tensor.into_primitive(), indices.into_primitive());
        let actual = Tensor::<TestBackend, 2>::from_primitive(gathered);
        let expected = tensor_ref.gather(1, indices_ref);

        expected
            .into_data()
            .assert_approx_eq(&actual.into_data(), 3);
    }
}
|
|
@ -0,0 +1,9 @@
|
|||
mod gather;
|
||||
mod scatter;
|
||||
mod select;
|
||||
mod slice;
|
||||
|
||||
pub use gather::*;
|
||||
pub use scatter::*;
|
||||
pub use select::*;
|
||||
pub use slice::*;
|
|
@ -0,0 +1,127 @@
|
|||
use crate::{
|
||||
element::WgpuElement,
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
kernel_wgsl!(Scatter, "../../template/index/scatter.wgsl");
|
||||
|
||||
pub(crate) fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
|
||||
dim: usize,
|
||||
tensor: WgpuTensor<E, D>,
|
||||
indices: WgpuTensor<I, D>,
|
||||
value: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let indices = kernel::into_continuous(indices);
|
||||
let tensor = kernel::into_continuous(tensor);
|
||||
let value = kernel::into_continuous(value);
|
||||
let tensor = match tensor.can_mut() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
let mut info = build_info(&[&tensor]);
|
||||
let mut strides = [0; D];
|
||||
let mut current = 1;
|
||||
let mut num_elems_per_workgroup = 1;
|
||||
|
||||
tensor
|
||||
.shape
|
||||
.dims
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.filter(|(index, _val)| *index != dim)
|
||||
.for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
||||
});
|
||||
|
||||
strides
|
||||
.into_iter()
|
||||
.for_each(|stride| info.push(stride as u32));
|
||||
|
||||
info.push(dim as u32);
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<Scatter, E, i32, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
|
||||
kernel,
|
||||
&[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{ReferenceBackend, TestBackend};
    use burn_tensor::{backend::Backend, Distribution, Int, Tensor};

    #[test]
    fn scatter_should_work_with_multiple_workgroups_2d_dim0() {
        same_as_reference(0, [256, 32]);
    }

    #[test]
    fn scatter_should_work_with_multiple_workgroups_2d_dim1() {
        same_as_reference(1, [32, 256]);
    }

    #[test]
    fn scatter_should_work_with_multiple_workgroups_3d_dim0() {
        same_as_reference(0, [256, 6, 6]);
    }

    #[test]
    fn scatter_should_work_with_multiple_workgroups_3d_dim1() {
        same_as_reference(1, [6, 256, 6]);
    }

    #[test]
    fn scatter_should_work_with_multiple_workgroups_3d_dim2() {
        same_as_reference(2, [6, 6, 256]);
    }

    /// Scatters random values along `dim` into a random tensor of the given `shape` on the
    /// test backend and checks the result against the reference backend.
    fn same_as_reference<const D: usize>(dim: usize, shape: [usize; D]) {
        TestBackend::seed(0);

        let tensor = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);
        let value = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);

        // One index per element, sampled as a float in [0, shape[dim]) and converted to int.
        let total: usize = shape.iter().product();
        let sampled = Tensor::<TestBackend, 1>::random(
            [total],
            Distribution::Uniform(0., shape[dim] as f32),
        );
        let indices = Tensor::<TestBackend, 1, Int>::from_data(sampled.into_data().convert())
            .reshape(shape);

        // Mirror every input on the reference backend before it is consumed.
        let tensor_ref = Tensor::<ReferenceBackend, D>::from_data(tensor.to_data());
        let value_ref = Tensor::<ReferenceBackend, D>::from_data(value.to_data());
        let indices_ref =
            Tensor::<ReferenceBackend, D, Int>::from_data(indices.to_data().convert());

        let scattered = scatter(
            dim,
            tensor.into_primitive(),
            indices.into_primitive(),
            value.into_primitive(),
        );
        let actual = Tensor::<TestBackend, D>::from_primitive(scattered);
        let expected = tensor_ref.scatter(dim, indices_ref, value_ref);

        expected
            .into_data()
            .assert_approx_eq(&actual.into_data(), 3);
    }
}
|
|
@ -0,0 +1,186 @@
|
|||
use crate::{
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
kernel_wgsl!(IndexSelect, "../../template/index/select.wgsl");
|
||||
kernel_wgsl!(
|
||||
SelectAssignInplace,
|
||||
"../../template/index/select_assign_inplace.wgsl"
|
||||
);
|
||||
|
||||
pub(crate) fn select<E: WgpuElement, I: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
dim: usize,
|
||||
indices: WgpuTensor<I, 1>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut output_shape = tensor.shape.clone();
|
||||
output_shape.dims[dim] = indices.shape.dims[0];
|
||||
let num_elems = output_shape.num_elements();
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(num_elems * std::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), output_shape, buffer);
|
||||
|
||||
let mut info = build_info(&[&tensor, &output]);
|
||||
info.push(dim as u32);
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<IndexSelect, E, I, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
kernel,
|
||||
&[
|
||||
&tensor.buffer,
|
||||
&indices.buffer,
|
||||
&output.buffer,
|
||||
&info_buffer,
|
||||
],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn select_assign<E: WgpuElement, I: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
dim: usize,
|
||||
indices: WgpuTensor<I, 1>,
|
||||
value: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let tensor = match tensor.can_mut() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
|
||||
let mut info = build_info(&[&tensor, &value]);
|
||||
let mut strides = [0; D];
|
||||
let mut current = 1;
|
||||
let mut num_elems_per_workgroup = 1;
|
||||
|
||||
tensor
|
||||
.shape
|
||||
.dims
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.filter(|(index, _val)| *index != dim)
|
||||
.for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
||||
});
|
||||
|
||||
strides
|
||||
.into_iter()
|
||||
.for_each(|stride| info.push(stride as u32));
|
||||
|
||||
info.push(dim as u32);
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<SelectAssignInplace, E, I, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
elemwise_workgroup(num_elems_per_workgroup, WORKGROUP),
|
||||
kernel,
|
||||
&[&tensor.buffer, &indices.buffer, &value.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{ReferenceBackend, TestBackend};
    use burn_tensor::{backend::Backend, Distribution, Int, Tensor};

    /// Selects the first 100 columns of a random 6 x 256 tensor and checks the result
    /// against the reference backend.
    #[test]
    fn select_should_work_with_multiple_workgroups() {
        let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);
        let indices = Tensor::<TestBackend, 1, Int>::arange(0..100);

        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
        let indices_ref =
            Tensor::<ReferenceBackend, 1, Int>::from_data(indices.to_data().convert());

        let actual = select(tensor.into_primitive(), 1, indices.into_primitive());
        let expected = tensor_ref.select(1, indices_ref);

        let actual_data = Tensor::<TestBackend, 2>::from_primitive(actual).into_data();
        expected.into_data().assert_approx_eq(&actual_data, 3);
    }

    #[test]
    fn select_assign_should_work_with_multiple_workgroups_2d_dim0() {
        select_assign_same_as_ref(0, [256, 6]);
    }

    #[test]
    fn select_assign_should_work_with_multiple_workgroups_2d_dim1() {
        select_assign_same_as_ref(1, [6, 256]);
    }

    #[test]
    fn select_assign_should_work_with_multiple_workgroups_3d_dim0() {
        select_assign_same_as_ref(0, [256, 6, 6]);
    }

    #[test]
    fn select_assign_should_work_with_multiple_workgroups_3d_dim1() {
        select_assign_same_as_ref(1, [6, 256, 6]);
    }

    #[test]
    fn select_assign_should_work_with_multiple_workgroups_3d_dim2() {
        select_assign_same_as_ref(2, [6, 6, 256]);
    }

    /// Runs `select_assign` along `dim` on random tensors of the given `shape` on the test
    /// backend and checks the result against the reference backend.
    fn select_assign_same_as_ref<const D: usize>(dim: usize, shape: [usize; D]) {
        TestBackend::seed(0);

        let tensor = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);
        let value = Tensor::<TestBackend, D>::random(shape, Distribution::Standard);

        // `shape[dim]` indices, sampled as floats in [0, shape[dim]) and converted to int.
        let sampled = Tensor::<TestBackend, 1>::random(
            [shape[dim]],
            Distribution::Uniform(0., shape[dim] as f32),
        );
        let indices = Tensor::<TestBackend, 1, Int>::from_data(sampled.into_data().convert());

        // Mirror every input on the reference backend before it is consumed.
        let tensor_ref = Tensor::<ReferenceBackend, D>::from_data(tensor.to_data());
        let value_ref = Tensor::<ReferenceBackend, D>::from_data(value.to_data());
        let indices_ref =
            Tensor::<ReferenceBackend, 1, Int>::from_data(indices.to_data().convert());

        let assigned = select_assign(
            tensor.into_primitive(),
            dim,
            indices.into_primitive(),
            value.into_primitive(),
        );
        let actual = Tensor::<TestBackend, D>::from_primitive(assigned);
        let expected = tensor_ref.select_assign(dim, indices_ref, value_ref);

        expected
            .into_data()
            .assert_approx_eq(&actual.into_data(), 3);
    }
}
|
|
@ -0,0 +1,139 @@
|
|||
use crate::{
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, elemwise_workgroup, KernelSettings},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
use std::ops::Range;
|
||||
|
||||
kernel_wgsl!(IndexRaw, "../../template/index/slice.wgsl");
|
||||
kernel_wgsl!(
|
||||
IndexAssignInplaceRaw,
|
||||
"../../template/index/slice_assign_inplace.wgsl"
|
||||
);
|
||||
|
||||
pub(crate) fn slice<E: WgpuElement, const D1: usize, const D2: usize>(
|
||||
tensor: WgpuTensor<E, D1>,
|
||||
indices: [Range<usize>; D2],
|
||||
) -> WgpuTensor<E, D1> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let mut dims = tensor.shape.dims;
|
||||
for i in 0..D2 {
|
||||
dims[i] = indices[i].end - indices[i].start;
|
||||
}
|
||||
let shape_output = Shape::new(dims);
|
||||
let num_elems = shape_output.num_elements();
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
|
||||
let mut info = build_info(&[&tensor, &output]);
|
||||
|
||||
for i in 0..D1 {
|
||||
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
|
||||
info.push(start as u32);
|
||||
}
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<IndexRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
kernel,
|
||||
&[&tensor.buffer, &output.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) fn slice_assign<E: WgpuElement, const D1: usize, const D2: usize>(
|
||||
tensor: WgpuTensor<E, D1>,
|
||||
indices: [Range<usize>; D2],
|
||||
value: WgpuTensor<E, D1>,
|
||||
) -> WgpuTensor<E, D1> {
|
||||
const WORKGROUP: usize = 32;
|
||||
|
||||
let tensor = match tensor.can_mut() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let mut info = build_info(&[&tensor, &value]);
|
||||
|
||||
for i in 0..D1 {
|
||||
let start = indices.get(i).map(|index| index.start).unwrap_or(0);
|
||||
info.push(start as u32);
|
||||
}
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor.context.compile_static::<KernelSettings<
|
||||
IndexAssignInplaceRaw,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP,
|
||||
WORKGROUP,
|
||||
1,
|
||||
>>();
|
||||
|
||||
tensor.context.execute(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
kernel,
|
||||
&[&tensor.buffer, &value.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::{ReferenceBackend, TestBackend};
    use burn_tensor::{Distribution, Tensor};

    /// Slices rows 3..5 and columns 45..256 out of a random 6 x 256 tensor and checks the
    /// result against the reference backend.
    #[test]
    fn slice_should_work_with_multiple_workgroups() {
        let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);
        let ranges = [3..5, 45..256];
        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());

        let actual = slice(tensor.into_primitive(), ranges.clone());
        let expected = tensor_ref.slice(ranges);

        let actual_data = Tensor::<TestBackend, 2>::from_primitive(actual).into_data();
        expected.into_data().assert_approx_eq(&actual_data, 3);
    }

    /// Assigns a random 2 x 211 block into rows 3..5 and columns 45..256 of a random
    /// 6 x 256 tensor and checks the result against the reference backend.
    #[test]
    fn slice_assign_should_work_with_multiple_workgroups() {
        let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Standard);
        let value = Tensor::<TestBackend, 2>::random([2, 211], Distribution::Standard);
        let ranges = [3..5, 45..256];

        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
        let value_ref = Tensor::<ReferenceBackend, 2>::from_data(value.to_data());

        let actual = slice_assign(
            tensor.into_primitive(),
            ranges.clone(),
            value.into_primitive(),
        );
        let expected = tensor_ref.slice_assign(ranges, value_ref);

        let actual_data = Tensor::<TestBackend, 2>::from_primitive(actual).into_data();
        expected.into_data().assert_approx_eq(&actual_data, 3);
    }
}
|
|
@ -3,6 +3,7 @@ mod binary_elemwise;
|
|||
mod cat;
|
||||
mod comparison;
|
||||
mod comparison_elem;
|
||||
mod index;
|
||||
mod mask;
|
||||
mod matmul;
|
||||
mod reduction;
|
||||
|
@ -20,5 +21,6 @@ pub use unary_scalar::*;
|
|||
pub(crate) use cat::*;
|
||||
pub(crate) use comparison::*;
|
||||
pub(crate) use comparison_elem::*;
|
||||
pub(crate) use index::*;
|
||||
pub(crate) use mask::*;
|
||||
pub(crate) use reduction::*;
|
||||
|
|
|
@ -59,14 +59,14 @@ mod tests {
|
|||
burn_tensor::testgen_matmul!();
|
||||
burn_tensor::testgen_reshape!();
|
||||
burn_tensor::testgen_transpose!();
|
||||
burn_tensor::testgen_index!();
|
||||
burn_tensor::testgen_slice!();
|
||||
burn_tensor::testgen_aggregation!();
|
||||
burn_tensor::testgen_arg!();
|
||||
burn_tensor::testgen_map_comparison!();
|
||||
burn_tensor::testgen_arange!();
|
||||
burn_tensor::testgen_mask!();
|
||||
burn_tensor::testgen_cat!();
|
||||
burn_tensor::testgen_index_select!();
|
||||
burn_tensor::testgen_select!();
|
||||
burn_tensor::testgen_gather_scatter!();
|
||||
|
||||
type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
|
||||
|
@ -88,11 +88,11 @@ mod tests {
|
|||
burn_autodiff::testgen_ad_matmul!();
|
||||
burn_autodiff::testgen_ad_reshape!();
|
||||
burn_autodiff::testgen_ad_transpose!();
|
||||
burn_autodiff::testgen_ad_index!();
|
||||
burn_autodiff::testgen_ad_slice!();
|
||||
burn_autodiff::testgen_ad_aggregation!();
|
||||
burn_autodiff::testgen_ad_cat!();
|
||||
burn_autodiff::testgen_ad_mask!();
|
||||
burn_autodiff::testgen_ad_index_select!();
|
||||
burn_autodiff::testgen_ad_select!();
|
||||
|
||||
// Once all operations will be implemented.
|
||||
// burn_tensor::testgen_all!();
|
||||
|
|
|
@ -1,18 +1,16 @@
|
|||
use crate::{
|
||||
comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
|
||||
context::WorkGroup,
|
||||
element::WgpuElement,
|
||||
kernel::{
|
||||
build_info, cat, comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
|
||||
mask_fill, mask_fill_inplace, mask_where, mask_where_inplace, KernelSettings,
|
||||
self, cat, comparison, comparison_elem, comparison_elem_inplace, comparison_inplace,
|
||||
mask_fill, mask_fill_inplace, mask_where, mask_where_inplace,
|
||||
},
|
||||
kernel_wgsl,
|
||||
pool::get_context,
|
||||
tensor::WgpuTensor,
|
||||
GraphicsApi, WgpuDevice,
|
||||
};
|
||||
use burn_tensor::{backend::Backend, Data, Shape};
|
||||
use std::{marker::PhantomData, mem, ops::Range};
|
||||
use std::{marker::PhantomData, mem};
|
||||
|
||||
pub type FloatElem<B> = <B as Backend>::FloatElem;
|
||||
pub type Device<B> = <B as Backend>::Device;
|
||||
|
@ -63,7 +61,7 @@ impl<G: GraphicsApi> BaseOps<G> {
|
|||
}
|
||||
|
||||
pub fn into_data<E: WgpuElement, const D: usize>(tensor: WgpuTensor<E, D>) -> Data<E, D> {
|
||||
let tensor = Self::into_continuous(tensor);
|
||||
let tensor = kernel::into_continuous(tensor);
|
||||
let bytes = tensor.context.read_buffer(tensor.buffer);
|
||||
let values = E::from_bytes(&bytes);
|
||||
|
||||
|
@ -109,230 +107,11 @@ impl<G: GraphicsApi> BaseOps<G> {
|
|||
shape: Shape<D2>,
|
||||
) -> WgpuTensor<E, D2> {
|
||||
// TODO: Not force standard layout all the time (improve performance).
|
||||
let tensor = Self::into_continuous(tensor);
|
||||
let tensor = kernel::into_continuous(tensor);
|
||||
|
||||
WgpuTensor::new(tensor.context, shape, tensor.buffer)
|
||||
}
|
||||
|
||||
pub fn into_continuous<E: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
if tensor.is_continuous() {
|
||||
return tensor;
|
||||
}
|
||||
|
||||
kernel_wgsl!(ContinuousRaw, "../template/continuous.wgsl");
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(tensor.shape.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), tensor.shape.clone(), buffer);
|
||||
let info = build_info(&[&tensor, &output]);
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<ContinuousRaw, E, i32, 256, 1, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[&tensor.buffer, &output.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn index<E: WgpuElement, const D1: usize, const D2: usize>(
|
||||
tensor: WgpuTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> WgpuTensor<E, D1> {
|
||||
kernel_wgsl!(IndexRaw, "../template/index/index.wgsl");
|
||||
|
||||
let mut dims = tensor.shape.dims;
|
||||
|
||||
for i in 0..D2 {
|
||||
dims[i] = indexes[i].end - indexes[i].start;
|
||||
}
|
||||
|
||||
let shape_output = Shape::new(dims);
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(shape_output.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
|
||||
let mut info = build_info(&[&tensor, &output]);
|
||||
|
||||
for i in 0..D1 {
|
||||
let start = indexes.get(i).map(|index| index.start).unwrap_or(0);
|
||||
info.push(start as u32);
|
||||
}
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<IndexRaw, E, i32, 256, 1, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[&tensor.buffer, &output.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn index_assign<E: WgpuElement, const D1: usize, const D2: usize>(
|
||||
tensor: WgpuTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: WgpuTensor<E, D1>,
|
||||
) -> WgpuTensor<E, D1> {
|
||||
kernel_wgsl!(
|
||||
IndexAssignInplaceRaw,
|
||||
"../template/index/index_assign_inplace.wgsl"
|
||||
);
|
||||
|
||||
let tensor = match tensor.can_mut() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
|
||||
let mut info = build_info(&[&tensor, &value]);
|
||||
|
||||
for i in 0..D1 {
|
||||
let start = indexes.get(i).map(|index| index.start).unwrap_or(0);
|
||||
info.push(start as u32);
|
||||
}
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<IndexAssignInplaceRaw, E, i32, 256, 1, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(value.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[&tensor.buffer, &value.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
pub fn index_select<E: WgpuElement, I: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: WgpuTensor<I, 1>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
kernel_wgsl!(IndexSelect, "../template/index/index_select.wgsl");
|
||||
|
||||
let mut output_shape = tensor.shape.clone();
|
||||
output_shape.dims[dim] = indexes.shape.dims[0];
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(std::mem::size_of::<E>() * output_shape.num_elements());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), output_shape, buffer);
|
||||
|
||||
let mut info = build_info(&[&tensor, &output]);
|
||||
info.push(dim as u32);
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<IndexSelect, E, I, 256, 1, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[
|
||||
&tensor.buffer,
|
||||
&indexes.buffer,
|
||||
&output.buffer,
|
||||
&info_buffer,
|
||||
],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn index_select_assign<E: WgpuElement, I: WgpuElement, const D: usize, const D2: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
dim: usize,
|
||||
indexes: WgpuTensor<I, 1>,
|
||||
values: WgpuTensor<E, D2>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
kernel_wgsl!(
|
||||
IndexSelectAssignInplace,
|
||||
"../template/index/index_select_assign_inplace.wgsl"
|
||||
);
|
||||
|
||||
let tensor = match tensor.can_mut() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape.dims[dim] = values.shape.dims[dim];
|
||||
let values = WgpuTensor::new(values.context, shape, values.buffer);
|
||||
let mut info = build_info(&[&tensor, &values]);
|
||||
info.push(dim as u32);
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<IndexSelectAssignInplace, E, I, 256, 1, 1>>();
|
||||
|
||||
let mut shape_tmp = values.shape;
|
||||
shape_tmp.dims[dim] = 1; // Just one thread for the dim.
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(shape_tmp.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[
|
||||
&tensor.buffer,
|
||||
&indexes.buffer,
|
||||
&values.buffer,
|
||||
&info_buffer,
|
||||
],
|
||||
);
|
||||
|
||||
tensor
|
||||
}
|
||||
|
||||
pub fn equal<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
|
@ -501,107 +280,4 @@ impl<G: GraphicsApi> BaseOps<G> {
|
|||
) -> WgpuTensor<E, D> {
|
||||
cat(tensors, dim)
|
||||
}
|
||||
|
||||
pub fn gather<E: WgpuElement, I: WgpuElement, const D: usize>(
|
||||
dim: usize,
|
||||
tensor: WgpuTensor<E, D>,
|
||||
indexes: WgpuTensor<I, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
kernel_wgsl!(Gather, "../template/gather.wgsl");
|
||||
let shape_output = indexes.shape.clone();
|
||||
let indexes = Self::into_continuous(indexes);
|
||||
|
||||
let buffer = tensor
|
||||
.context
|
||||
.create_buffer(shape_output.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(tensor.context.clone(), shape_output, buffer);
|
||||
let mut info = build_info(&[&tensor, &output]);
|
||||
info.push(dim as u32);
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<Gather, E, i32, 256, 1, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[
|
||||
&tensor.buffer,
|
||||
&indexes.buffer,
|
||||
&output.buffer,
|
||||
&info_buffer,
|
||||
],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub fn scatter<E: WgpuElement, I: WgpuElement, const D: usize>(
|
||||
dim: usize,
|
||||
tensor: WgpuTensor<E, D>,
|
||||
indexes: WgpuTensor<I, D>,
|
||||
value: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
kernel_wgsl!(Scatter, "../template/scatter.wgsl");
|
||||
|
||||
const WORKGROUP: usize = 256;
|
||||
|
||||
let indexes = Self::into_continuous(indexes);
|
||||
let tensor = Self::into_continuous(tensor);
|
||||
let value = Self::into_continuous(value);
|
||||
let tensor = match tensor.can_mut() {
|
||||
true => tensor,
|
||||
false => tensor.copy(),
|
||||
};
|
||||
let mut info = build_info(&[&tensor]);
|
||||
let mut strides = [0; D];
|
||||
let mut current = 1;
|
||||
let mut num_elems_per_workgroup = 1;
|
||||
|
||||
tensor
|
||||
.shape
|
||||
.dims
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.filter(|(index, _val)| *index != dim)
|
||||
.for_each(|(index, val)| {
|
||||
strides[index] = current;
|
||||
current *= val;
|
||||
num_elems_per_workgroup *= tensor.shape.dims[index];
|
||||
});
|
||||
|
||||
strides
|
||||
.into_iter()
|
||||
.for_each(|stride| info.push(stride as u32));
|
||||
|
||||
info.push(dim as u32);
|
||||
|
||||
let info_buffer = tensor
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = tensor
|
||||
.context
|
||||
.compile_static::<KernelSettings<Scatter, E, i32, WORKGROUP, 1, 1>>();
|
||||
|
||||
tensor.context.execute(
|
||||
WorkGroup::new(
|
||||
f32::ceil(num_elems_per_workgroup as f32 / WORKGROUP as f32) as u32,
|
||||
1,
|
||||
1,
|
||||
),
|
||||
kernel,
|
||||
&[&tensor.buffer, &indexes.buffer, &value.buffer, &info_buffer],
|
||||
);
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,14 +1,12 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use burn_tensor::{ops::BoolTensorOps, ops::IntTensorOps, Data, Shape};
|
||||
|
||||
use super::{BaseOps, BoolTensor, Device, IntTensor};
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
kernel,
|
||||
tensor::WgpuTensor,
|
||||
GraphicsApi, WgpuBackend,
|
||||
};
|
||||
|
||||
use super::{BaseOps, BoolTensor, Device, IntTensor};
|
||||
use burn_tensor::{ops::BoolTensorOps, ops::IntTensorOps, Data, Shape};
|
||||
use std::ops::Range;
|
||||
|
||||
impl<G, F, I> BoolTensorOps<WgpuBackend<G, F, I>> for WgpuBackend<G, F, I>
|
||||
where
|
||||
|
@ -76,19 +74,19 @@ where
|
|||
BaseOps::<G>::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
fn bool_slice<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<Self, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> BoolTensor<Self, D1> {
|
||||
BaseOps::<G>::index(tensor, indexes)
|
||||
kernel::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
fn bool_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<Self, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: BoolTensor<Self, D1>,
|
||||
) -> BoolTensor<Self, D1> {
|
||||
BaseOps::<G>::index_assign(tensor, indexes, value)
|
||||
kernel::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use super::numeric::NumericOps;
|
||||
use super::{BaseOps, BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
|
||||
use crate::kernel::{
|
||||
matmul_tiling_2d_default, unary_default, unary_inplace_default, unary_scalar_default,
|
||||
self, matmul_tiling_2d_default, unary_default, unary_inplace_default, unary_scalar_default,
|
||||
unary_scalar_inplace_default,
|
||||
};
|
||||
use crate::unary_scalar_inplace;
|
||||
|
@ -138,8 +138,8 @@ where
|
|||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
let lhs = BaseOps::<G>::into_continuous(lhs);
|
||||
let rhs = BaseOps::<G>::into_continuous(rhs);
|
||||
let lhs = kernel::into_continuous(lhs);
|
||||
let rhs = kernel::into_continuous(rhs);
|
||||
|
||||
matmul_tiling_2d_default::<FloatElem<Self>, D>(lhs, rhs)
|
||||
}
|
||||
|
@ -162,50 +162,50 @@ where
|
|||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self, D>,
|
||||
indexes: IntTensor<Self, D>,
|
||||
indices: IntTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
BaseOps::<G>::gather(dim, tensor, indexes)
|
||||
kernel::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: FloatTensor<Self, D>,
|
||||
indexes: IntTensor<Self, D>,
|
||||
indices: IntTensor<Self, D>,
|
||||
value: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
BaseOps::<G>::scatter(dim, tensor, indexes, value)
|
||||
kernel::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn index_select<const D: usize>(
|
||||
fn select<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
dim: usize,
|
||||
indexes: IntTensor<Self, 1>,
|
||||
indices: IntTensor<Self, 1>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
BaseOps::<G>::index_select(tensor, dim, indexes)
|
||||
kernel::select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: FloatTensor<Self, D1>,
|
||||
fn select_assign<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
dim: usize,
|
||||
indexes: IntTensor<Self, 1>,
|
||||
value: FloatTensor<Self, D2>,
|
||||
) -> FloatTensor<Self, D1> {
|
||||
BaseOps::<G>::index_select_assign(tensor, dim, indexes, value)
|
||||
indices: IntTensor<Self, 1>,
|
||||
value: FloatTensor<Self, D>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
kernel::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
fn slice<const D1: usize, const D2: usize>(
|
||||
tensor: FloatTensor<Self, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> FloatTensor<Self, D1> {
|
||||
BaseOps::<G>::index(tensor, indexes)
|
||||
kernel::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
fn slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: FloatTensor<Self, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: FloatTensor<Self, D1>,
|
||||
) -> FloatTensor<Self, D1> {
|
||||
BaseOps::<G>::index_assign(tensor, indexes, value)
|
||||
kernel::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn mask_where<const D: usize>(
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::IntTensorOps, Data, Shape};
|
||||
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
GraphicsApi, WgpuBackend,
|
||||
kernel, GraphicsApi, WgpuBackend,
|
||||
};
|
||||
use burn_tensor::{ops::IntTensorOps, Data, Shape};
|
||||
use std::ops::Range;
|
||||
|
||||
use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
|
||||
|
||||
|
@ -52,19 +50,19 @@ where
|
|||
BaseOps::<G>::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn int_index<const D1: usize, const D2: usize>(
|
||||
fn int_slice<const D1: usize, const D2: usize>(
|
||||
tensor: IntTensor<Self, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
) -> IntTensor<Self, D1> {
|
||||
BaseOps::<G>::index(tensor, indexes)
|
||||
kernel::slice(tensor, ranges)
|
||||
}
|
||||
|
||||
fn int_index_assign<const D1: usize, const D2: usize>(
|
||||
fn int_slice_assign<const D1: usize, const D2: usize>(
|
||||
tensor: IntTensor<Self, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
ranges: [Range<usize>; D2],
|
||||
value: IntTensor<Self, D1>,
|
||||
) -> IntTensor<Self, D1> {
|
||||
BaseOps::<G>::index_assign(tensor, indexes, value)
|
||||
kernel::slice_assign(tensor, ranges, value)
|
||||
}
|
||||
|
||||
fn int_mask_where<const D: usize>(
|
||||
|
@ -86,35 +84,35 @@ where
|
|||
fn int_gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: IntTensor<Self, D>,
|
||||
indexes: IntTensor<Self, D>,
|
||||
indices: IntTensor<Self, D>,
|
||||
) -> IntTensor<Self, D> {
|
||||
BaseOps::<G>::gather(dim, tensor, indexes)
|
||||
kernel::gather(dim, tensor, indices)
|
||||
}
|
||||
|
||||
fn int_scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: IntTensor<Self, D>,
|
||||
indexes: IntTensor<Self, D>,
|
||||
indices: IntTensor<Self, D>,
|
||||
value: IntTensor<Self, D>,
|
||||
) -> IntTensor<Self, D> {
|
||||
BaseOps::<G>::scatter(dim, tensor, indexes, value)
|
||||
kernel::scatter(dim, tensor, indices, value)
|
||||
}
|
||||
|
||||
fn int_index_select_dim<const D: usize>(
|
||||
_tensor: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
|
||||
_dim: usize,
|
||||
_indexes: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
|
||||
) -> <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
|
||||
todo!()
|
||||
fn int_select<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self, 1>,
|
||||
) -> IntTensor<Self, D> {
|
||||
kernel::select(tensor, dim, indices)
|
||||
}
|
||||
|
||||
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
_tensor: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D1>,
|
||||
_dim: usize,
|
||||
_indexes: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<1>,
|
||||
_value: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D2>,
|
||||
) -> <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<D1> {
|
||||
todo!()
|
||||
fn int_select_assign<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
dim: usize,
|
||||
indices: IntTensor<Self, 1>,
|
||||
value: IntTensor<Self, D>,
|
||||
) -> IntTensor<Self, D> {
|
||||
kernel::select_assign(tensor, dim, indices, value)
|
||||
}
|
||||
|
||||
fn int_cat<const D: usize>(tensors: Vec<IntTensor<Self, D>>, dim: usize) -> IntTensor<Self, D> {
|
||||
|
|
|
@ -57,22 +57,22 @@ where
|
|||
todo!()
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes(
|
||||
fn max_pool2d_with_indices(
|
||||
_x: <WgpuBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
) -> burn_tensor::ops::MaxPool2dWithIndexes<WgpuBackend<G, F, I>> {
|
||||
) -> burn_tensor::ops::MaxPool2dWithIndices<WgpuBackend<G, F, I>> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn max_pool2d_with_indexes_backward(
|
||||
fn max_pool2d_with_indices_backward(
|
||||
_x: <WgpuBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_kernel_size: [usize; 2],
|
||||
_stride: [usize; 2],
|
||||
_padding: [usize; 2],
|
||||
_output_grad: <WgpuBackend<G, F, I> as Backend>::TensorPrimitive<4>,
|
||||
_indexes: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<4>,
|
||||
_indices: <WgpuBackend<G, F, I> as Backend>::IntTensorPrimitive<4>,
|
||||
) -> burn_tensor::ops::MaxPool2dBackward<WgpuBackend<G, F, I>> {
|
||||
todo!()
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ var<storage, read> input: array<{{ elem }}>;
|
|||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> indexes: array<{{ int }}>;
|
||||
var<storage, read> indices: array<{{ int }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
|
@ -14,9 +14,15 @@ var<storage, read_write> output: array<{{ elem }}>;
|
|||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let rank = info[0];
|
||||
let dim = info[4u * rank + 1u];
|
||||
|
||||
|
@ -31,9 +37,9 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
|||
} else {
|
||||
let stride_output = info[i + rank];
|
||||
let shape_output = info[i + 3u * rank];
|
||||
index_input += global_id.x / stride_output % shape_output * stride_input;
|
||||
index_input += id / stride_output % shape_output * stride_input;
|
||||
}
|
||||
}
|
||||
|
||||
output[global_id.x] = input[index_input + u32(indexes[global_id.x]) * stride];
|
||||
output[id] = input[index_input + u32(indices[id]) * stride];
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> indexes: array<{{ int }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> values: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
// Accumulates `values` into `input` along one axis, at the positions listed in
// `indexes`.
//
// Layout of `info`, as read below (`rank` = info[0]):
//   info[1 ..= rank]               strides of `input`
//   info[rank + 1 ..= 2 * rank]    strides of `values`
//   info[2 * rank + 1 ..= 3 * rank] shape of `input`
//   info[3 * rank + 1 ..= 4 * rank] shape of `values`
//   info[4 * rank + 1]             the axis (`dim`) being assigned along
//
// Each invocation owns one position over the axes other than `dim` and walks
// the whole of `dim` sequentially, so no two invocations write the same
// element of `input`.
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let rank = info[0];
    let dim = info[4u * rank + 1u];

    var index_input_offset = 0u;
    var index_values_offset = 0u;

    var stride_input_dim = 0u;
    var stride_values_dim = 0u;
    var shape_values_dim = 0u;

    // Running product of the extents of the non-`dim` axes visited so far.
    // While looping it is the stride of the current axis in the reduced index
    // space; afterwards it is the number of invocations that have work to do.
    var num_elem = 1u;

    // Visit the axes from innermost to outermost so the invocation id is
    // decomposed as a mixed-radix number over the non-`dim` axes. Dividing by
    // the tensors' own strides instead is only right when `dim` is outermost.
    for (var j = 0u; j < rank; j++) {
        let i = rank - j;
        let stride_input = info[i];
        let stride_values = info[i + rank];
        let shape_input = info[i + 2u * rank];
        let shape_values = info[i + 3u * rank];

        if i - 1u != dim {
            index_input_offset += global_id.x / num_elem % shape_input * stride_input;
            index_values_offset += global_id.x / num_elem % shape_values * stride_values;
            // Multiply, not add: the bound is the number of positions.
            num_elem *= shape_input;
        } else {
            shape_values_dim = shape_values;

            stride_input_dim = stride_input;
            stride_values_dim = stride_values;
        }
    }

    // Surplus invocations of the last workgroup must not touch memory;
    // `>=` because ids start at zero.
    if global_id.x >= num_elem {
        return;
    }

    for (var i = 0u; i < shape_values_dim; i++) {
        let index = u32(indexes[i]);
        input[index_input_offset + index * stride_input_dim] += values[index_values_offset + i * stride_values_dim];
    }
}
|
||||
|
|
@ -4,7 +4,7 @@ var<storage, read_write> input: array<{{ elem }}>;
|
|||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> indexes: array<{{ int }}>;
|
||||
var<storage, read> indices: array<{{ int }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
|
@ -14,11 +14,18 @@ var<storage, read> value: array<{{ elem }}>;
|
|||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let rank = info[0];
|
||||
let dim = info[3u * rank + 1u];
|
||||
|
||||
let shape = info[dim + rank + 1u];
|
||||
let stride = info[dim + 1u];
|
||||
|
||||
|
@ -31,17 +38,17 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
|||
let shape_input = info[i + rank];
|
||||
let stride_tmp = info[i + 2u * rank];
|
||||
|
||||
index_offset += global_id.x / stride_tmp % shape_input * stride_input;
|
||||
num_elems *= shape_input;
|
||||
index_offset += id / stride_tmp % shape_input * stride_input;
|
||||
}
|
||||
}
|
||||
|
||||
if global_id.x >= num_elems {
|
||||
if id >= num_elems {
|
||||
return;
|
||||
}
|
||||
|
||||
for (var i = 0u; i < shape; i++) {
|
||||
let index = i * stride + index_offset;
|
||||
input[index_offset + stride * u32(indexes[index])] += value[index];
|
||||
input[index_offset + stride * u32(indices[index])] += value[index];
|
||||
}
|
||||
}
|
|
@ -4,7 +4,7 @@ var<storage, read> input: array<{{ elem }}>;
|
|||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> indexes: array<{{ int }}>;
|
||||
var<storage, read> indices: array<{{ int }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
|
@ -14,9 +14,15 @@ var<storage, read_write> output: array<{{ elem }}>;
|
|||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let rank = info[0];
|
||||
let dim = info[4u * rank + 1u];
|
||||
var index_input = 0u;
|
||||
|
@ -27,15 +33,15 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
|||
let shape_input = info[i + 2u * rank];
|
||||
let shape_output = info[i + 3u * rank];
|
||||
|
||||
let index = global_id.x / stride_output % shape_output;
|
||||
let index = id / stride_output % shape_output;
|
||||
|
||||
if i - 1u == dim {
|
||||
index_input += u32(indexes[index]) * stride_input;
|
||||
index_input += u32(indices[index]) * stride_input;
|
||||
} else {
|
||||
index_input += index * stride_input;
|
||||
}
|
||||
}
|
||||
|
||||
output[global_id.x] = input[index_input];
|
||||
output[id] = input[index_input];
|
||||
}
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> indices: array<{{ int }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> values: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
// Accumulates `values` into `input` along one axis, at the positions listed in
// `indices`.
//
// Layout of `info`, as read below (`rank` = info[0]):
//   info[1 ..= rank]                strides of `input`
//   info[rank + 1 ..= 2 * rank]     strides of `values`
//   info[2 * rank + 1 ..= 3 * rank] shape of `input`
//   info[3 * rank + 1 ..= 4 * rank] shape of `values`
//   info[4 * rank + 1 ..= 5 * rank] strides used to decompose the invocation id
//   info[5 * rank + 1]              the axis (`dim`) being assigned along
//
// Each invocation owns one position over the axes other than `dim` and walks
// the whole of `dim` sequentially.
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flatten the 2-D dispatch grid into a single linear invocation id.
    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
    let rank = info[0];
    let dim = info[5u * rank + 1u];

    let dim_stride_input = info[dim + 1u];
    let dim_stride_value = info[dim + rank + 1u];
    let dim_shape_value = info[dim + 3u * rank + 1u];

    var index_input_offset = 0u;
    var index_value_offset = 0u;

    // Number of positions over the non-`dim` axes, i.e. how many invocations
    // have work to do.
    var num_elem = 1u;

    for (var i = 1u; i <= rank; i++) {
        if i - 1u != dim {
            let stride_input = info[i];
            let stride_value = info[i + rank];
            let shape_input = info[i + 2u * rank];
            let shape_value = info[i + 3u * rank];
            let stride_tmp = info[i + 4u * rank];

            num_elem *= shape_input;
            index_input_offset += id / stride_tmp % shape_input * stride_input;
            index_value_offset += id / stride_tmp % shape_value * stride_value;
        }
    }

    // Surplus invocations of the last workgroup must not touch memory.
    if id >= num_elem {
        return;
    }

    for (var i = 0u; i < dim_shape_value; i++) {
        let index = u32(indices[i]);
        input[index_input_offset + index * dim_stride_input] += values[index_value_offset + i * dim_stride_value];
    }
}
|
|
@ -10,9 +10,15 @@ var<storage, read_write> output: array<{{ elem }}>;
|
|||
@binding(2)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim: u32 = info[0];
|
||||
var index_input: u32 = 0u;
|
||||
|
||||
|
@ -22,10 +28,10 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
|||
let shape_output = info[i + 3u * dim];
|
||||
let start = info[i + 4u * dim];
|
||||
|
||||
let num_block = global_id.x / stride_output % shape_output + start;
|
||||
let num_block = id / stride_output % shape_output + start;
|
||||
|
||||
index_input += num_block * stride_input;
|
||||
}
|
||||
|
||||
output[global_id.x] = input[index_input];
|
||||
output[id] = input[index_input];
|
||||
}
|
|
@ -10,12 +10,19 @@ var<storage, read> value: array<{{ elem }}>;
|
|||
@binding(2)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, 1, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
|
||||
let dim: u32 = info[0];
|
||||
var index_input: u32 = 0u;
|
||||
var index_value: u32 = 0u;
|
||||
var num_elems = 0u;
|
||||
|
||||
for (var i: u32 = 1u; i <= dim; i++) {
|
||||
let stride_input = info[i];
|
||||
|
@ -24,7 +31,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
|||
let shape_value = info[i + 3u * dim];
|
||||
let start = info[i + 4u * dim];
|
||||
|
||||
let num_block = global_id.x / stride_value % shape_value;
|
||||
let num_block = id / stride_value % shape_value;
|
||||
|
||||
index_input += (num_block + start) * stride_input;
|
||||
index_value += num_block * stride_value;
|
|
@ -23,7 +23,7 @@ fn main(
|
|||
@builtin(local_invocation_index) local_idx: u32,
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
) {
|
||||
// Indexes
|
||||
// Indices
|
||||
let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);
|
||||
let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);
|
||||
let batch = global_id.z;
|
||||
|
|
|
@ -19,7 +19,7 @@ var<storage, read> info: array<u32>;
|
|||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>
|
||||
) {
|
||||
// Indexes
|
||||
// Indices
|
||||
let row = global_id.x;
|
||||
let col = global_id.y;
|
||||
let batch = global_id.z;
|
||||
|
|
|
@ -65,7 +65,7 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
|
|||
|
||||
// Print out predictions for each sample
|
||||
for (i, text) in samples.into_iter().enumerate() {
|
||||
let prediction = predictions.clone().index([i..i + 1]); // Get prediction for current sample
|
||||
let prediction = predictions.clone().slice([i..i + 1]); // Get prediction for current sample
|
||||
let logits = prediction.to_data(); // Convert prediction tensor to data
|
||||
let class_index = prediction.argmax(1).into_data().convert::<i32>().value[0]; // Get class index with the highest value
|
||||
let class = D::class_name(class_index as usize); // Get class name
|
||||
|
|
|
@ -110,7 +110,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
let output = self.output.forward(encoded);
|
||||
|
||||
let output_classification = output
|
||||
.index([0..batch_size, 0..1])
|
||||
.slice([0..batch_size, 0..1])
|
||||
.reshape([batch_size, self.n_classes]);
|
||||
|
||||
let loss = CrossEntropyLoss::new(None);
|
||||
|
@ -148,7 +148,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
.forward(TransformerEncoderInput::new(embedding).mask_pad(mask_pad));
|
||||
let output = self.output.forward(encoded);
|
||||
let output = output
|
||||
.index([0..batch_size, 0..1])
|
||||
.slice([0..batch_size, 0..1])
|
||||
.reshape([batch_size, self.n_classes]);
|
||||
|
||||
softmax(output, 1)
|
||||
|
|
|
@ -57,9 +57,9 @@ impl<B: Backend> Batcher<TextGenerationItem, TrainingTextGenerationBatch<B>>
|
|||
let inputs = item
|
||||
.tokens
|
||||
.clone()
|
||||
.index([0..batch_size, 0..seq_length - 1]);
|
||||
let targets = item.tokens.index([0..batch_size, 1..seq_length]);
|
||||
let mask_pad = item.mask_pad.index([0..batch_size, 0..seq_length - 1]);
|
||||
.slice([0..batch_size, 0..seq_length - 1]);
|
||||
let targets = item.tokens.slice([0..batch_size, 1..seq_length]);
|
||||
let mask_pad = item.mask_pad.slice([0..batch_size, 0..seq_length - 1]);
|
||||
|
||||
TrainingTextGenerationBatch::new(inputs, targets, mask_pad)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue