mirror of https://github.com/tracel-ai/burn.git
Support signed value for `Tensor::arange` (#1238)
This commit is contained in:
parent
5bef9d8432
commit
419e53bc42
|
@ -1983,7 +1983,7 @@ dependencies = [
|
|||
"futures-sink",
|
||||
"futures-util",
|
||||
"http",
|
||||
"indexmap 2.2.1",
|
||||
"indexmap 2.2.2",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
|
@ -2108,6 +2108,29 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hoot"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df22a4d90f1b0e65fe3e0d6ee6a4608cc4d81f4b2eb3e670f44bb6bde711e452"
|
||||
dependencies = [
|
||||
"httparse",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hootbin"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "354e60868e49ea1a39c44b9562ad207c4259dc6eabf9863bf3b0f058c55cfdb2"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"hoot",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hound"
|
||||
version = "3.5.1"
|
||||
|
@ -2199,9 +2222,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
|||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.5.0"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6"
|
||||
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
|
||||
dependencies = [
|
||||
"unicode-bidi",
|
||||
"unicode-normalization",
|
||||
|
@ -2273,9 +2296,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.2.1"
|
||||
version = "2.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "433de089bd45971eecf4668ee0ee8f4cec17db4f8bd8f7bc3197a6ce37aa7d9b"
|
||||
checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.14.3",
|
||||
|
@ -2659,7 +2682,7 @@ dependencies = [
|
|||
"bitflags 2.4.2",
|
||||
"codespan-reporting",
|
||||
"hexf-parse",
|
||||
"indexmap 2.2.1",
|
||||
"indexmap 2.2.2",
|
||||
"log",
|
||||
"num-traits",
|
||||
"rustc-hash",
|
||||
|
@ -2768,6 +2791,12 @@ dependencies = [
|
|||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-conv"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.45"
|
||||
|
@ -3541,7 +3570,7 @@ dependencies = [
|
|||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls-pemfile",
|
||||
"rustls-pemfile 1.0.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
|
@ -3669,9 +3698,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.30"
|
||||
version = "0.38.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "322394588aaf33c24007e8bb3238ee3e4c5c09c084ab32bc73890b99ff326bca"
|
||||
checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949"
|
||||
dependencies = [
|
||||
"bitflags 2.4.2",
|
||||
"errno",
|
||||
|
@ -3682,24 +3711,27 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.21.10"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba"
|
||||
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"sct",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.6.3"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00"
|
||||
checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pemfile",
|
||||
"rustls-pemfile 2.0.0",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
@ -3714,12 +3746,29 @@ dependencies = [
|
|||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.101.7"
|
||||
name = "rustls-pemfile"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
|
||||
checksum = "35e4980fa29e4c4b212ffb3db068a564cbf560e51d3944b7c88bd8bf5bec64f4"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a716eb65e3158e90e17cd93d855216e27bde02745ab842f2cab4a39dba1bacf"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
|
@ -3798,16 +3847,6 @@ version = "1.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.9.2"
|
||||
|
@ -4156,14 +4195,13 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
|
|||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.13.0"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06"
|
||||
checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.48",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4364,13 +4402,14 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.31"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e"
|
||||
checksum = "fe80ced77cbfb4cb91a94bf72b378b4b6791a0d9b7f09d0be747d1bdff4e68bd"
|
||||
dependencies = [
|
||||
"deranged",
|
||||
"itoa",
|
||||
"libc",
|
||||
"num-conv",
|
||||
"num_threads",
|
||||
"powerfmt",
|
||||
"serde",
|
||||
|
@ -4386,10 +4425,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
|
|||
|
||||
[[package]]
|
||||
name = "time-macros"
|
||||
version = "0.2.16"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f"
|
||||
checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774"
|
||||
dependencies = [
|
||||
"num-conv",
|
||||
"time-core",
|
||||
]
|
||||
|
||||
|
@ -4656,17 +4696,19 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
|||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.9.1"
|
||||
version = "2.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97"
|
||||
checksum = "399dd89e2af196ae4f83a47bb37a1455e664fe2fed97b3ae68a1c4a3f8216e76"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"flate2",
|
||||
"hootbin",
|
||||
"log",
|
||||
"native-tls",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -4676,9 +4718,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.5.0"
|
||||
version = "2.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633"
|
||||
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
|
@ -4858,9 +4900,12 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.25.3"
|
||||
version = "0.26.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
|
||||
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weezl"
|
||||
|
@ -5294,6 +5339,12 @@ dependencies = [
|
|||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.6"
|
||||
|
|
|
@ -333,4 +333,8 @@ impl<B: Backend> IntTensorOps<Autodiff<B>> for Autodiff<B> {
|
|||
) -> Vec<<Autodiff<B> as Backend>::IntTensorPrimitive<D>> {
|
||||
B::int_chunk(tensor, chunks, dim)
|
||||
}
|
||||
|
||||
fn int_arange(range: std::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self, 1> {
|
||||
B::int_arange(range, device)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,10 +87,6 @@ impl<B: Backend> FloatTensorOps<Self> for Autodiff<B> {
|
|||
}
|
||||
}
|
||||
|
||||
fn float_arange(range: std::ops::Range<usize>, device: &Device<Self>) -> IntTensor<Self, 1> {
|
||||
B::float_arange(range, device)
|
||||
}
|
||||
|
||||
fn float_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
|
||||
AutodiffTensor::new(B::float_empty(shape, device))
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ mod tests {
|
|||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -50,7 +50,7 @@ mod tests {
|
|||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -79,7 +79,7 @@ mod tests {
|
|||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -106,7 +106,7 @@ mod tests {
|
|||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let device = Default::default();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -230,7 +230,7 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
@ -238,14 +238,14 @@ mod tests {
|
|||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out, &device)
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -781,7 +781,7 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
@ -789,14 +789,14 @@ mod tests {
|
|||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out, &device)
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -241,7 +241,7 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
@ -249,14 +249,14 @@ mod tests {
|
|||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1], &device)
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -654,7 +654,7 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
@ -662,14 +662,14 @@ mod tests {
|
|||
)
|
||||
.require_grad();
|
||||
let bias = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels[1], &device)
|
||||
TestTensorInt::arange(0..self.channels[1] as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
)
|
||||
.require_grad();
|
||||
let x = TestAutodiffTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -418,4 +418,15 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
|
|||
) -> Vec<TchTensor<i64, D>> {
|
||||
TchOps::chunk(tensor, chunks, dim)
|
||||
}
|
||||
|
||||
fn int_arange(range: Range<i64>, device: &LibTorchDevice) -> TchTensor<i64, 1> {
|
||||
let device: tch::Device = (*device).into();
|
||||
let mut tensor = tch::Tensor::arange(range.end - range.start, (tch::Kind::Int64, device));
|
||||
|
||||
if range.start != 0 {
|
||||
tensor = tensor.f_add_scalar_(range.start).unwrap();
|
||||
}
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,20 +42,6 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
|
|||
}
|
||||
}
|
||||
|
||||
fn float_arange(range: Range<usize>, device: &LibTorchDevice) -> TchTensor<i64, 1> {
|
||||
let device: tch::Device = (*device).into();
|
||||
let mut tensor = tch::Tensor::arange(
|
||||
range.end as i64 - range.start as i64,
|
||||
(tch::Kind::Int64, device),
|
||||
);
|
||||
|
||||
if range.start != 0 {
|
||||
tensor = tensor.f_add_scalar_(range.start as i64).unwrap();
|
||||
}
|
||||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
fn float_repeat<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
|
|
|
@ -11,8 +11,8 @@ where
|
|||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
pub fn arange(range: Range<usize>, device: &B::Device) -> Self {
|
||||
Tensor::new(B::float_arange(range, device))
|
||||
pub fn arange(range: Range<i64>, device: &B::Device) -> Self {
|
||||
Tensor::new(B::int_arange(range, device))
|
||||
}
|
||||
|
||||
/// Returns a new integer tensor on the specified device.
|
||||
|
@ -21,8 +21,8 @@ where
|
|||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
/// * `step` - The step between each value.
|
||||
pub fn arange_step(range: Range<usize>, step: usize, device: &B::Device) -> Self {
|
||||
Tensor::new(B::float_arange_step(range, step, device))
|
||||
pub fn arange_step(range: Range<i64>, step: usize, device: &B::Device) -> Self {
|
||||
Tensor::new(B::int_arange_step(range, step, device))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -470,8 +470,8 @@ where
|
|||
let height = shape.dims[D - 2];
|
||||
let width = shape.dims[D - 1];
|
||||
|
||||
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height, &self.device());
|
||||
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width, &self.device());
|
||||
let row_indices: Tensor<B, 1, Int> = Tensor::arange(0..height as i64, &self.device());
|
||||
let col_indices: Tensor<B, 1, Int> = Tensor::arange(0..width as i64, &self.device());
|
||||
|
||||
let mut row_shape = [1; D];
|
||||
row_shape[D - 2] = height;
|
||||
|
@ -582,7 +582,7 @@ where
|
|||
///
|
||||
/// * `size` - The size of the square matrix.
|
||||
pub fn diagonal(size: usize, device: &B::Device) -> Self {
|
||||
let indices = Tensor::<B, 1, Int>::arange(0..size, device).unsqueeze();
|
||||
let indices = Tensor::<B, 1, Int>::arange(0..size as i64, device).unsqueeze();
|
||||
let ones = K::ones([1, size].into(), device);
|
||||
let zeros = K::zeros([size, size].into(), device);
|
||||
Self::new(K::scatter(0, zeros, indices, ones))
|
||||
|
|
|
@ -958,4 +958,43 @@ pub trait IntTensorOps<B: Backend> {
|
|||
) -> Vec<IntTensor<B, D>> {
|
||||
chunk::<B, D, Int>(tensor, chunks, dim)
|
||||
}
|
||||
|
||||
/// Creates a new tensor with values from the given range with the given step size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values.
|
||||
/// * `step` - The step size.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given values.
|
||||
fn int_arange_step(range: Range<i64>, step: usize, device: &Device<B>) -> IntTensor<B, 1> {
|
||||
let value = range
|
||||
.step_by(step)
|
||||
.map(|i| i.elem())
|
||||
.collect::<Vec<IntElem<B>>>();
|
||||
let shape = Shape::new([value.len()]);
|
||||
let data = Data::new(value, shape);
|
||||
B::int_from_data(data, device)
|
||||
}
|
||||
|
||||
/// Creates a new tensor with values from the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given values.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// Uses `arange_step` with a step size of 1 under the hood.
|
||||
fn int_arange(range: Range<i64>, device: &Device<B>) -> IntTensor<B, 1> {
|
||||
Self::int_arange_step(range, 1, device)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -148,24 +148,6 @@ pub trait FloatTensorOps<B: Backend> {
|
|||
device: &Device<B>,
|
||||
) -> FloatTensor<B, D>;
|
||||
|
||||
/// Creates a new tensor with values from the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given values.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// Uses `arange_step` with a step size of 1 under the hood.
|
||||
fn float_arange(range: Range<usize>, device: &Device<B>) -> IntTensor<B, 1> {
|
||||
Self::float_arange_step(range, 1, device)
|
||||
}
|
||||
|
||||
/// Converts float tensor to int tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -177,27 +159,6 @@ pub trait FloatTensorOps<B: Backend> {
|
|||
/// The int tensor with the same data as the float tensor.
|
||||
fn float_into_int<const D: usize>(tensor: FloatTensor<B, D>) -> IntTensor<B, D>;
|
||||
|
||||
/// Creates a new tensor with values from the given range with the given step size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values.
|
||||
/// * `step` - The step size.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given values.
|
||||
fn float_arange_step(range: Range<usize>, step: usize, device: &Device<B>) -> IntTensor<B, 1> {
|
||||
let value = range
|
||||
.step_by(step)
|
||||
.map(|i| (i as i64).elem())
|
||||
.collect::<Vec<IntElem<B>>>();
|
||||
let shape = Shape::new([value.len()]);
|
||||
let data = Data::new(value, shape);
|
||||
B::int_from_data(data, device)
|
||||
}
|
||||
|
||||
/// Creates an empty tensor with the given shape.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
@ -58,7 +58,7 @@ mod tests {
|
|||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let device = Default::default();
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -90,7 +90,7 @@ mod tests {
|
|||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &y.device())
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -69,7 +69,7 @@ mod tests {
|
|||
fn assert_output(self, y: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.length]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &y.device())
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -94,7 +94,7 @@ mod tests {
|
|||
fn assert_output(self, y: TestTensor<4>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &y.device())
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &y.device())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -110,20 +110,20 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
);
|
||||
let bias = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out, &device)
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
);
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -132,18 +132,18 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weight = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape(shape_weight)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let bias = TestTensor::from(
|
||||
TestTensorInt::arange(0..self.channels_out, &device)
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -112,20 +112,20 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weights = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_weights.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
|
||||
.reshape(shape_weights)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
);
|
||||
let bias = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..self.channels_out, &device)
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
&device,
|
||||
);
|
||||
let x = TestTensor::from_data(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -302,18 +302,18 @@ mod tests {
|
|||
]);
|
||||
let device = Default::default();
|
||||
let weights = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_weights.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_weights.num_elements() as i64, &device)
|
||||
.reshape(shape_weights)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let bias = TestTensor::from(
|
||||
TestTensorInt::arange(0..self.channels_out, &device)
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &device)
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -90,7 +90,7 @@ mod tests {
|
|||
fn assert_shape(self, expected_shape: [usize; 3]) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &Default::default())
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
@ -112,7 +112,7 @@ mod tests {
|
|||
fn assert_output(self, expected: TestTensor<3>) {
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements(), &Default::default())
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
|
|
|
@ -7,8 +7,20 @@ mod tests {
|
|||
#[test]
|
||||
fn test_arange() {
|
||||
let device = <TestBackend as Backend>::Device::default();
|
||||
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(2..5, &device);
|
||||
assert_eq!(tensor.clone().into_data(), Data::from([2, 3, 4]));
|
||||
assert_eq!(tensor.into_data(), Data::from([2, 3, 4]));
|
||||
|
||||
// Test arange with negative numbers
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(-10..-5, &device);
|
||||
assert_eq!(tensor.into_data(), Data::from([-10, -9, -8, -7, -6]));
|
||||
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(-3..0, &device);
|
||||
assert_eq!(tensor.into_data(), Data::from([-3, -2, -1]));
|
||||
|
||||
// Test arange with a mix of positive and negative numbers
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(-2..3, &device);
|
||||
assert_eq!(tensor.clone().into_data(), Data::from([-2, -1, 0, 1, 2]));
|
||||
assert_eq!(tensor.device(), device);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,14 @@ mod tests {
|
|||
|
||||
// Test correct sequence of numbers when the range is 0..2 and the step is 5
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..2, 5, &device);
|
||||
assert_eq!(tensor.clone().into_data(), Data::from([0]));
|
||||
assert_eq!(tensor.into_data(), Data::from([0]));
|
||||
|
||||
// Test correct sequence of numbers when the range includes negative numbers
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-3..3, 2, &device);
|
||||
assert_eq!(tensor.into_data(), Data::from([-3, -1, 1]));
|
||||
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(-5..1, 5, &device);
|
||||
assert_eq!(tensor.clone().into_data(), Data::from([-5, 0]));
|
||||
assert_eq!(tensor.device(), device);
|
||||
}
|
||||
|
||||
|
|
|
@ -71,7 +71,7 @@ mod tests {
|
|||
const END: usize = 100;
|
||||
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(START..END, &device);
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(START as i64..END as i64, &device);
|
||||
let tensor_float = cast::<i32, f32, 1>(tensor.clone().into_primitive());
|
||||
|
||||
let data_int = tensor.into_data();
|
||||
|
|
|
@ -96,7 +96,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
let mask_pad = item.mask_pad.to_device(device);
|
||||
|
||||
// Calculate token and position embeddings, and combine them
|
||||
let index_positions = Tensor::arange(0..seq_length, device)
|
||||
let index_positions = Tensor::arange(0..seq_length as i64, device)
|
||||
.reshape([1, seq_length])
|
||||
.repeat(0, batch_size);
|
||||
let embedding_positions = self.embedding_pos.forward(index_positions);
|
||||
|
@ -136,7 +136,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
let mask_pad = item.mask_pad.to_device(device);
|
||||
|
||||
// Calculate token and position embeddings, and combine them
|
||||
let index_positions = Tensor::arange(0..seq_length, device)
|
||||
let index_positions = Tensor::arange(0..seq_length as i64, device)
|
||||
.reshape([1, seq_length])
|
||||
.repeat(0, batch_size);
|
||||
let embedding_positions = self.embedding_pos.forward(index_positions);
|
||||
|
|
|
@ -64,7 +64,7 @@ impl<B: Backend> TextGenerationModel<B> {
|
|||
let targets = item.targets.to_device(device);
|
||||
let mask_pad = item.mask_pad.to_device(device);
|
||||
|
||||
let index_positions = Tensor::arange(0..seq_length, device)
|
||||
let index_positions = Tensor::arange(0..seq_length as i64, device)
|
||||
.reshape([1, seq_length])
|
||||
.repeat(0, batch_size);
|
||||
|
||||
|
|
Loading…
Reference in New Issue