Support signed value for `Tensor::arange` (#1238)

This commit is contained in:
yurzhang 2024-02-07 22:33:01 +08:00 committed by GitHub
parent 5bef9d8432
commit 419e53bc42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 212 additions and 145 deletions

133
Cargo.lock generated
View File

@ -1983,7 +1983,7 @@ dependencies = [
"futures-sink",
"futures-util",
"http",
"indexmap 2.2.1",
"indexmap 2.2.2",
"slab",
"tokio",
"tokio-util",
@ -2108,6 +2108,29 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "hoot"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df22a4d90f1b0e65fe3e0d6ee6a4608cc4d81f4b2eb3e670f44bb6bde711e452"
dependencies = [
"httparse",
"log",
]
[[package]]
name = "hootbin"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "354e60868e49ea1a39c44b9562ad207c4259dc6eabf9863bf3b0f058c55cfdb2"
dependencies = [
"fastrand",
"hoot",
"serde",
"serde_json",
"thiserror",
]
[[package]]
name = "hound"
version = "3.5.1"
@ -2199,9 +2222,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.5.0"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
@ -2273,9 +2296,9 @@ dependencies = [
[[package]]
name = "indexmap"
version = "2.2.1"
version = "2.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "433de089bd45971eecf4668ee0ee8f4cec17db4f8bd8f7bc3197a6ce37aa7d9b"
checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520"
dependencies = [
"equivalent",
"hashbrown 0.14.3",
@ -2659,7 +2682,7 @@ dependencies = [
"bitflags 2.4.2",
"codespan-reporting",
"hexf-parse",
"indexmap 2.2.1",
"indexmap 2.2.2",
"log",
"num-traits",
"rustc-hash",
@ -2768,6 +2791,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.45"
@ -3541,7 +3570,7 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
"rustls-pemfile",
"rustls-pemfile 1.0.4",
"serde",
"serde_json",
"serde_urlencoded",
@ -3669,9 +3698,9 @@ dependencies = [
[[package]]
name = "rustix"
version = "0.38.30"
version = "0.38.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca"
checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949"
dependencies = [
"bitflags 2.4.2",
"errno",
@ -3682,24 +3711,27 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.21.10"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba"
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
dependencies = [
"log",
"ring",
"rustls-pki-types",
"rustls-webpki",
"sct",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.6.3"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00"
checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pemfile 2.0.0",
"rustls-pki-types",
"schannel",
"security-framework",
]
@ -3714,12 +3746,29 @@ dependencies = [
]
[[package]]
name = "rustls-webpki"
version = "0.101.7"
name = "rustls-pemfile"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4"
dependencies = [
"base64 0.21.7",
"rustls-pki-types",
]
[[package]]
name = "rustls-pki-types"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a716eb65e3158e90e17cd93d855216e27bde02745ab842f2cab4a39dba1bacf"
[[package]]
name = "rustls-webpki"
version = "0.102.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
@ -3798,16 +3847,6 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sct"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "security-framework"
version = "2.9.2"
@ -4156,14 +4195,13 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
[[package]]
name = "synstructure"
version = "0.13.0"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06"
checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.48",
"unicode-xid",
]
[[package]]
@ -4364,13 +4402,14 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.31"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e"
checksum = "fe80ced77cbfb4cb91a94bf72b378b4b6791a0d9b7f09d0be747d1bdff4e68bd"
dependencies = [
"deranged",
"itoa",
"libc",
"num-conv",
"num_threads",
"powerfmt",
"serde",
@ -4386,10 +4425,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
[[package]]
name = "time-macros"
version = "0.2.16"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f"
checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774"
dependencies = [
"num-conv",
"time-core",
]
@ -4656,17 +4696,19 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.9.1"
version = "2.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97"
checksum = "399dd89e2af196ae4f83a47bb37a1455e664fe2fed97b3ae68a1c4a3f8216e76"
dependencies = [
"base64 0.21.7",
"flate2",
"hootbin",
"log",
"native-tls",
"once_cell",
"rustls",
"rustls-native-certs",
"rustls-pki-types",
"rustls-webpki",
"serde",
"serde_json",
@ -4676,9 +4718,9 @@ dependencies = [
[[package]]
name = "url"
version = "2.5.0"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
dependencies = [
"form_urlencoded",
"idna",
@ -4858,9 +4900,12 @@ dependencies = [
[[package]]
name = "webpki-roots"
version = "0.25.3"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "weezl"
@ -5294,6 +5339,12 @@ dependencies = [
"synstructure",
]
[[package]]
name = "zeroize"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
[[package]]
name = "zip"
version = "0.6.6"

View File

@ -333,4 +333,8 @@ impl<B: Backend> IntTensorOps<Autodiff<B>> for Autodiff<B> {
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive<D>> {
B::int_chunk(tensor, chunks, dim)
}
/// Creates a 1-D int tensor holding the values `range.start..range.end`.
///
/// Int tensors carry no gradients, so the autodiff decorator adds no tracking
/// here and simply forwards the call to the inner backend `B`.
fn int_arange(range: std::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self, 1> {
B::int_arange(range, device)
}
}

View File

@ -87,10 +87,6 @@ impl<B: Backend> FloatTensorOps<Self> for Autodiff<B> {
}
}
/// Creates a 1-D int tensor from `range` — note the return type is `IntTensor`
/// despite the `float_` prefix (see the signature). Forwarded unchanged to the
/// inner backend `B`, since int tensors are not differentiable.
fn float_arange(range: std::ops::Range<usize>, device: &Device<Self>) -> IntTensor<Self, 1> {
B::float_arange(range, device)
}
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
AutodiffTensor::new(B::float_empty(shape, device))
}

View File

@ -34,7 +34,7 @@ mod tests {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -50,7 +50,7 @@ mod tests {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -79,7 +79,7 @@ mod tests {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -106,7 +106,7 @@ mod tests {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let device = Default::default();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -230,7 +230,7 @@ mod tests {
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape(shape_weight)
.into_data()
.convert(),
@ -238,14 +238,14 @@ mod tests {
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out, &device)
TestTensorInt::arange(0..self.channels_out as i64, &device)
.into_data()
.convert(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -781,7 +781,7 @@ mod tests {
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape(shape_weight)
.into_data()
.convert(),
@ -789,14 +789,14 @@ mod tests {
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels_out, &device)
TestTensorInt::arange(0..self.channels_out as i64, &device)
.into_data()
.convert(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -241,7 +241,7 @@ mod tests {
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape(shape_weight)
.into_data()
.convert(),
@ -249,14 +249,14 @@ mod tests {
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1], &device)
TestTensorInt::arange(0..self.channels[1] as i64, &device)
.into_data()
.convert(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -654,7 +654,7 @@ mod tests {
]);
let device = Default::default();
let weight = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape(shape_weight)
.into_data()
.convert(),
@ -662,14 +662,14 @@ mod tests {
)
.require_grad();
let bias = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..self.channels[1], &device)
TestTensorInt::arange(0..self.channels[1] as i64, &device)
.into_data()
.convert(),
&device,
)
.require_grad();
let x = TestAutodiffTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -418,4 +418,15 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
) -> Vec<TchTensor<i64, D>> {
TchOps::chunk(tensor, chunks, dim)
}
/// Builds a 1-D `i64` tensor containing the values `range.start..range.end`.
///
/// torch's `arange` always counts from zero, so we first generate
/// `range.end - range.start` elements and then shift them in place by
/// `range.start` whenever the range does not begin at zero.
fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor<i64, 1> {
    let target: tch::Device = (*device).into();
    let mut values = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, target));
    if range.start != 0 {
        values = values.f_add_scalar_(range.start).unwrap();
    }
    TchTensor::new(values)
}
}

View File

@ -42,20 +42,6 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
}
}
/// Builds a 1-D `i64` tensor containing the values `range.start..range.end`.
/// Note: despite the `float_` prefix, the result is an int (`i64`) tensor,
/// per the return type.
fn float_arange(range: Range<usize>, device: &LibTorchDevice) -> TchTensor<i64, 1> {
let device: tch::Device = (*device).into();
// torch's `arange` counts from zero, so generate `end - start` elements first.
let mut tensor = tch::Tensor::arange(
range.end as i64 - range.start as i64,
(tch::Kind::Int64, device),
);
// Shift in place so the sequence actually starts at `range.start`.
if range.start != 0 {
tensor = tensor.f_add_scalar_(range.start as i64).unwrap();
}
TchTensor::new(tensor)
}
fn float_repeat<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,

View File

@ -11,8 +11,8 @@ where
///
/// * `range` - The range of values to generate.
/// * `device` - The device to create the tensor on.
pub fn arange(range: Range<usize>, device: &B::Device) -> Self {
Tensor::new(B::float_arange(range, device))
pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
Tensor::new(B::int_arange(range, device))
}
/// Returns a new integer tensor on the specified device.
@ -21,8 +21,8 @@ where
///
/// * `range` - The range of values to generate.
/// * `step` - The step between each value.
pub fn arange_step(range: Range<usize>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::float_arange_step(range, step, device))
pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::int_arange_step(range, step, device))
}
}

View File

@ -470,8 +470,8 @@ where
let height = shape.dims[D - 2];
let width = shape.dims[D - 1];
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height, &self.device());
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width, &self.device());
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, &self.device());
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, &self.device());
let mut row_shape = [1; D];
row_shape[D - 2] = height;
@ -582,7 +582,7 @@ where
///
/// * `size` - The size of the square matrix.
pub fn diagonal(size: usize, device: &B::Device) -> Self {
let indices = Tensor::<B, 1, Int>::arange(0..size, device).unsqueeze();
let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze();
let ones = K::ones([1, size].into(), device);
let zeros = K::zeros([size, size].into(), device);
Self::new(K::scatter(0, zeros, indices, ones))

View File

@ -958,4 +958,43 @@ pub trait IntTensorOps<B: Backend> {
) -> Vec<IntTensor<B, D>> {
chunk::<B, D, Int>(tensor, chunks, dim)
}
/// Creates a new 1-D tensor whose values walk `range` with stride `step`.
///
/// # Arguments
///
/// * `range` - The (possibly negative) range of values.
/// * `step` - The step size between consecutive values.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor containing the generated values.
///
/// # Remarks
///
/// Relies on `Iterator::step_by`, which panics when `step == 0`.
fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B, 1> {
    let values: Vec<IntElem<B>> = range.step_by(step).map(|i| i.elem()).collect();
    let shape = Shape::new([values.len()]);
    B::int_from_data(Data::new(values, shape), device)
}
/// Creates a new tensor with values from the given range.
///
/// # Arguments
///
/// * `range` - The range of values. Signed (`i64`), so negative bounds such as
///   `-3..3` are supported.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given values.
///
/// # Remarks
///
/// Uses `arange_step` with a step size of 1 under the hood.
fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B, 1> {
Self::int_arange_step(range, 1, device)
}
}

View File

@ -148,24 +148,6 @@ pub trait FloatTensorOps<B: Backend> {
device: &Device<B>,
) -> FloatTensor<B, D>;
/// Creates a new tensor with values from the given range.
///
/// # Arguments
///
/// * `range` - The range of values.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given values — an `IntTensor`, despite the `float_`
/// prefix (see the return type).
///
/// # Remarks
///
/// Uses `arange_step` with a step size of 1 under the hood.
fn float_arange(range: Range<usize>, device: &Device<B>) -> IntTensor<B, 1> {
Self::float_arange_step(range, 1, device)
}
/// Converts float tensor to int tensor.
///
/// # Arguments
@ -177,27 +159,6 @@ pub trait FloatTensorOps<B: Backend> {
/// The int tensor with the same data as the float tensor.
fn float_into_int<const D: usize>(tensor: FloatTensor<B, D>) -> IntTensor<B, D>;
/// Creates a new 1-D tensor whose values walk `range` with stride `step`.
///
/// # Arguments
///
/// * `range` - The range of values.
/// * `step` - The step size between consecutive values.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor containing the generated values — as ints, despite the
/// `float_` prefix (see the return type).
///
/// # Remarks
///
/// Relies on `Iterator::step_by`, which panics when `step == 0`.
fn float_arange_step(range: Range<usize>, step: usize, device: &Device<B>) -> IntTensor<B, 1> {
    let values: Vec<IntElem<B>> = range.step_by(step).map(|i| (i as i64).elem()).collect();
    let shape = Shape::new([values.len()]);
    B::int_from_data(Data::new(values, shape), device)
}
/// Creates an empty tensor with the given shape.
///
/// # Arguments

View File

@ -58,7 +58,7 @@ mod tests {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let device = Default::default();
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -90,7 +90,7 @@ mod tests {
fn assert_output(self, y: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements(), &y.device())
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -69,7 +69,7 @@ mod tests {
fn assert_output(self, y: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements(), &y.device())
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -94,7 +94,7 @@ mod tests {
fn assert_output(self, y: TestTensor<4>) {
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements(), &y.device())
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -110,20 +110,20 @@ mod tests {
]);
let device = Default::default();
let weight = TestTensor::from_data(
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape(shape_weight)
.into_data()
.convert(),
&device,
);
let bias = TestTensor::from_data(
TestTensorInt::arange(0..self.channels_out, &device)
TestTensorInt::arange(0..self.channels_out as i64, &device)
.into_data()
.convert(),
&device,
);
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -132,18 +132,18 @@ mod tests {
]);
let device = Default::default();
let weight = TestTensor::from(
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape(shape_weight)
.into_data()
.convert(),
);
let bias = TestTensor::from(
TestTensorInt::arange(0..self.channels_out, &device)
TestTensorInt::arange(0..self.channels_out as i64, &device)
.into_data()
.convert(),
);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -112,20 +112,20 @@ mod tests {
]);
let device = Default::default();
let weights = TestTensor::from_data(
TestTensorInt::arange(0..shape_weights.num_elements(), &device)
TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
.reshape(shape_weights)
.into_data()
.convert(),
&device,
);
let bias = TestTensor::from_data(
TestTensorInt::arange(0..self.channels_out, &device)
TestTensorInt::arange(0..self.channels_out as i64, &device)
.into_data()
.convert(),
&device,
);
let x = TestTensor::from_data(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -302,18 +302,18 @@ mod tests {
]);
let device = Default::default();
let weights = TestTensor::from(
TestTensorInt::arange(0..shape_weights.num_elements(), &device)
TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
.reshape(shape_weights)
.into_data()
.convert(),
);
let bias = TestTensor::from(
TestTensorInt::arange(0..self.channels_out, &device)
TestTensorInt::arange(0..self.channels_out as i64, &device)
.into_data()
.convert(),
);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements(), &device)
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -90,7 +90,7 @@ mod tests {
fn assert_shape(self, expected_shape: [usize; 3]) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements(), &Default::default())
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
.reshape(shape_x)
.into_data()
.convert(),
@ -112,7 +112,7 @@ mod tests {
fn assert_output(self, expected: TestTensor<3>) {
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements(), &Default::default())
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
.reshape(shape_x)
.into_data()
.convert(),

View File

@ -7,8 +7,20 @@ mod tests {
#[test]
fn test_arange() {
let device = <TestBackend as Backend>::Device::default();
// Positive range: upper bound is exclusive, so 2..5 yields [2, 3, 4].
let tensor = Tensor::<TestBackend, 1, Int>::arange(2..5, &device);
assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4]));
// NOTE(review): this assert repeats the one above (likely a diff-merge artifact) — confirm.
assert_eq!(tensor.into_data(), Data::from([2, 3, 4]));
// Test arange with negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange(-10..-5, &device);
assert_eq!(tensor.into_data(), Data::from([-10, -9, -8, -7, -6]));
let tensor = Tensor::<TestBackend, 1, Int>::arange(-3..0, &device);
assert_eq!(tensor.into_data(), Data::from([-3, -2, -1]));
// Test arange with a mix of positive and negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange(-2..3, &device);
assert_eq!(tensor.clone().into_data(), Data::from([-2, -1, 0, 1, 2]));
// The result must live on the device it was requested for.
assert_eq!(tensor.device(), device);
}
}

View File

@ -18,7 +18,14 @@ mod tests {
// Test correct sequence of numbers when the range is 0..2 and the step is 5
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..2, 5, &device);
assert_eq!(tensor.clone().into_data(), Data::from([0]));
assert_eq!(tensor.into_data(), Data::from([0]));
// Test correct sequence of numbers when the range includes negative numbers
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-3..3, 2, &device);
assert_eq!(tensor.into_data(), Data::from([-3, -1, 1]));
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-5..1, 5, &device);
assert_eq!(tensor.clone().into_data(), Data::from([-5, 0]));
assert_eq!(tensor.device(), device);
}

View File

@ -71,7 +71,7 @@ mod tests {
const END: usize = 100;
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(START..END, &device);
let tensor = Tensor::<TestBackend, 1, Int>::arange(START as i64..END as i64, &device);
let tensor_float = cast::<i32, f32, 1>(tensor.clone().into_primitive());
let data_int = tensor.into_data();

View File

@ -96,7 +96,7 @@ impl<B: Backend> TextClassificationModel<B> {
let mask_pad = item.mask_pad.to_device(device);
// Calculate token and position embeddings, and combine them
let index_positions = Tensor::arange(0..seq_length, device)
let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length])
.repeat(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions);
@ -136,7 +136,7 @@ impl<B: Backend> TextClassificationModel<B> {
let mask_pad = item.mask_pad.to_device(device);
// Calculate token and position embeddings, and combine them
let index_positions = Tensor::arange(0..seq_length, device)
let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length])
.repeat(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions);

View File

@ -64,7 +64,7 @@ impl<B: Backend> TextGenerationModel<B> {
let targets = item.targets.to_device(device);
let mask_pad = item.mask_pad.to_device(device);
let index_positions = Tensor::arange(0..seq_length, device)
let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length])
.repeat(0, batch_size);