cargo +nightly fmt (#1017)

Alex Errant 2023-12-12 12:29:06 -06:00 committed by GitHub
parent 1a5f252ac6
commit 610d64095e
29 changed files with 424 additions and 323 deletions


@ -16,9 +16,10 @@ impl BreadthFirstSearch {
let mut visited = HashSet::with_capacity(root.order);
let mut parents = Vec::with_capacity(root.order);
let mut steps = graph.steps();
let root_step = steps
.remove(&root.id)
.expect("Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?");
let root_step = steps.remove(&root.id).expect(
"Root node should have a step registered, did you forget to call \
`Tensor::register_grad` on the tensor where you need gradients?",
);
visited.insert(root.id.clone());
parents.append(&mut root.parents.clone());

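Most of the rewraps in this commit rely on Rust's string-continuation escape: a trailing backslash inside a string literal removes the newline that follows it and any leading whitespace on the next line, so the reflowed messages stay byte-for-byte identical to the original single-line ones. A minimal standalone sketch (not part of the diff) illustrating the equivalence:

fn main() {
    let single_line = "Root node should have a step registered, did you forget to call `Tensor::register_grad` on the tensor where you need gradients?";
    // The trailing `\` strips the newline and the indentation of the continuation line.
    let wrapped = "Root node should have a step registered, did you forget to call \
                   `Tensor::register_grad` on the tensor where you need gradients?";
    assert_eq!(single_line, wrapped);
}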

@ -103,49 +103,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
match bias {
Some(bias) => {
match Conv2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
Some(bias) => match Conv2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match Conv2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv2d(x.primitive, weight.primitive, None, options),
B::conv2d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match Conv2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
}
B::conv2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv2d(x.primitive, weight.primitive, None, options))
}
}
},
}
}
@ -211,57 +207,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
match bias {
Some(bias) => {
match ConvTranspose2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
),
Some(bias) => match ConvTranspose2DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match ConvTranspose2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
None,
options,
)),
}
}
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match ConvTranspose2DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose2d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose2d(
x.primitive,
weight.primitive,
None,
options,
)),
},
}
}
@ -322,49 +314,45 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
}
match bias {
Some(bias) => {
match Conv1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
Some(bias) => match Conv1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match Conv1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv1d(x.primitive, weight.primitive, None, options),
B::conv1d(x.primitive, weight.primitive, Some(bias.primitive), options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match Conv1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
}
B::conv1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::conv1d(x.primitive, weight.primitive, None, options))
}
}
},
}
}
@ -430,57 +418,53 @@ impl<B: Backend> ModuleOps<Autodiff<B>> for Autodiff<B> {
}
match bias {
Some(bias) => {
match ConvTranspose1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
),
Some(bias) => match ConvTranspose1DWithBias
.prepare(
[x.node, weight.node, bias.node],
[x.graph, weight.graph, bias.graph],
)
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
bias.primitive.clone(),
options.clone(),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
}
}
None => {
match ConvTranspose1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
None,
options,
)),
}
}
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
Some(bias.primitive),
options,
)),
},
None => match ConvTranspose1DNoBias
.prepare([x.node, weight.node], [x.graph, weight.graph])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(
(
x.primitive.clone(),
weight.primitive.clone(),
options.clone(),
),
B::conv_transpose1d(x.primitive, weight.primitive, None, options),
),
OpsKind::UnTracked(prep) => prep.finish(B::conv_transpose1d(
x.primitive,
weight.primitive,
None,
options,
)),
},
}
}

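Beyond string rewrapping, the recurring change in the convolution ops above is structural: a match arm whose block body contained nothing but another match expression loses the redundant braces, turning `Some(bias) => { match ... }` into `Some(bias) => match ...`. A minimal sketch (unrelated to the Burn code above) showing that the two shapes are the same expression:

fn describe(value: Option<i32>) -> &'static str {
    match value {
        // Equivalent to `Some(n) => { match n { ... } }`: the braces around a lone
        // trailing expression are redundant, so the formatter drops them.
        Some(n) => match n {
            0 => "zero",
            _ => "non-zero",
        },
        None => "missing",
    }
}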

@ -3,6 +3,9 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize
let channels_out_div_by_group = channels_out % groups == 0;
if !channels_in_div_by_group && !channels_out_div_by_group {
panic!("Both channels must be divisible by the number of groups. Got channels_in={channels_in}, channels_out={channels_out}, groups={groups}");
panic!(
"Both channels must be divisible by the number of groups. Got \
channels_in={channels_in}, channels_out={channels_out}, groups={groups}"
);
}
}


@ -133,10 +133,12 @@ impl Initializer {
fn xavier_std(&self, fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
let fan_in = fan_in.expect(
"Can't use Xavier initialization without specifying fan in. Use init_with method and provide fan_in.",
"Can't use Xavier initialization without specifying fan in. Use init_with method and \
provide fan_in.",
);
let fan_out = fan_out.expect(
"Can't use Xavier initialization without specifying fan out. Use init_with method and provide fan_out.",
"Can't use Xavier initialization without specifying fan out. Use init_with method and \
provide fan_out.",
);
sqrt(2.0 / (fan_in + fan_out) as f64)
}


@ -45,10 +45,11 @@ impl BinaryCrossEntropyLossConfig {
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. \
Got {}",
alpha
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(


@ -53,10 +53,11 @@ impl CrossEntropyLossConfig {
fn assertions(&self) {
if let Some(alpha) = self.smoothing {
assert!(
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
alpha
);
(0.0..=1.).contains(&alpha),
"Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. \
Got {}",
alpha
);
};
if let Some(weights) = self.weights.as_ref() {
assert!(


@ -78,7 +78,13 @@ impl<const D: usize, B: Backend> BatchNorm<B, D> {
// Should be move to a compilation error when const generic support that kind of
// validation. https://github.com/rust-lang/rust/issues/76560
if D + 2 != DI {
panic!("BatchNorm{}D can only be applied on tensors of size {} with the following shape [batch_size, channels, ...], received {}D tensor", D, D+2, DI);
panic!(
"BatchNorm{}D can only be applied on tensors of size {} with the following shape \
[batch_size, channels, ...], received {}D tensor",
D,
D + 2,
DI
);
}
match B::ad_enabled() {


@ -156,7 +156,7 @@ mod tests {
use crate::optim::{GradientsParams, Optimizer};
use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use crate::tensor::{Data, Distribution, Tensor};
use crate::{nn, TestAutodiffBackend, TestBackend};
use crate::{nn, nn::Linear, TestAutodiffBackend, TestBackend};
const LEARNING_RATE: LearningRate = 0.01;
@ -262,7 +262,7 @@ mod tests {
}
fn create_adagrad(
) -> OptimizerAdaptor<AdaGrad<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
) -> OptimizerAdaptor<AdaGrad<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend>
{
let config = AdaGradConfig::new();
AdaGrad {


@ -317,7 +317,7 @@ mod tests {
use crate::optim::{GradientsParams, Optimizer};
use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use crate::tensor::{Data, Distribution, Tensor};
use crate::{nn, TestAutodiffBackend, TestBackend};
use crate::{nn, nn::Linear, TestAutodiffBackend, TestBackend};
use tempfile::TempDir;
const LEARNING_RATE: LearningRate = 0.01;
@ -509,7 +509,7 @@ mod tests {
}
fn create_rmsprop(
) -> OptimizerAdaptor<RMSProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
) -> OptimizerAdaptor<RMSProp<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend>
{
RMSPropConfig {
alpha: 0.99,


@ -599,8 +599,9 @@ where
let conn_pool = self.conn_pool.as_ref().unwrap();
let connection = conn_pool.get()?;
let create_table_statement = format!(
"create table if not exists {split} (row_id integer primary key autoincrement not null, item blob not null)"
);
"create table if not exists {split} (row_id integer primary key autoincrement not \
null, item blob not null)"
);
connection.execute(create_table_statement.as_str(), [])?;


@ -108,7 +108,8 @@ fn constant_impl(ast: &syn::DeriveInput) -> TokenStream {
burn::constant!(module);
}
impl #generics_module_ad burn::module::AutodiffModule<B> for #name #generics_ty #generics_where {
impl #generics_module_ad burn::module::AutodiffModule<B>
for #name #generics_ty #generics_where {
burn::constant!(ad_module, #name #generics_ty);
}
};


@ -47,7 +47,10 @@ impl ModuleCodegen for StructModuleCodegen {
});
quote! {
fn collect_devices(&self, devices: burn::module::Devices<B>) -> burn::module::Devices<B> {
fn collect_devices(
&self,
devices: burn::module::Devices<B>
) -> burn::module::Devices<B> {
#body
devices


@ -304,7 +304,11 @@ pub(crate) mod tests {
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4> {
let tensor3 = tensor1.matmul(tensor2);
let tensor4 = self.conv2d.forward(tensor3);
@ -376,7 +380,11 @@ pub(crate) mod tests {
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4> {
let tensor3 = tensor1.matmul(tensor2.clone());
let tensor4 = self.conv2d.forward(tensor2);
let output = tensor3.matmul(tensor4);


@ -70,7 +70,8 @@ macro_rules! batch_norm_serialize {
}};
($self:expr, $serializer:expr, $dim:expr) => {{
let record: BatchNormRecord<SerializationBackend, $dim> = batch_norm_serialize!(record $self);
let record: BatchNormRecord<SerializationBackend, $dim> =
batch_norm_serialize!(record $self);
let item = Record::into_item::<PS>(record);
item.serialize($serializer)


@ -326,7 +326,11 @@ mod tests {
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4, Bool> {
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4, Bool> {
let tensor3 = tensor1.equal(tensor2);
tensor3


@ -93,7 +93,11 @@ mod tests {
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4> {
let tensor3 = burn::tensor::Tensor::cat([tensor1, tensor2].into(), 1);
tensor3


@ -93,7 +93,11 @@ mod tests {
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 2>, tensor2: Tensor<B, 2, Int>) -> Tensor<B, 2> {
pub fn forward(
&self,
tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 2, Int>
) -> Tensor<B, 2> {
let tensor3 = tensor1.gather(1, tensor2);
tensor3


@ -85,7 +85,11 @@ mod tests {
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
tensor2: Tensor<B, 4>
) -> Tensor<B, 4> {
let tensor3 = tensor1.matmul(tensor2);
tensor3


@ -284,7 +284,11 @@ where
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);
if shape_value != shape_indices {
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indices.dims, shape_value.dims);
panic!(
"Invalid dimension: the shape of the index tensor should be the same as the value \
tensor: Index {:?} value {:?}",
shape_indices.dims, shape_value.dims
);
}
let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
@ -355,9 +359,10 @@ where
for i in 0..D - 1 {
if shape_tensor.dims[i] != shape_indices.dims[i] {
panic!(
"Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}",
shape_tensor.dims, shape_indices.dims
);
"Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \
{:?}",
shape_tensor.dims, shape_indices.dims
);
}
batch_size *= shape_indices.dims[i];
}


@ -102,8 +102,10 @@ impl TensorCheck {
check = check.register(
"Narrow",
TensorError::new(format!(
"Can't narrow at dimension {}, start exceeds the size of the tensor along this dimension (Size={})",
dim, tensor.shape().dims[dim]
"Can't narrow at dimension {}, start exceeds the size of the tensor along \
this dimension (Size={})",
dim,
tensor.shape().dims[dim]
)),
);
}
@ -112,8 +114,10 @@ impl TensorCheck {
check = check.register(
"Narrow",
TensorError::new(format!(
"Can't narrow at dimension {}, start + length exceeds the size of the tensor along this dimension (Size={})",
dim, tensor.shape().dims[dim]
"Can't narrow at dimension {}, start + length exceeds the size of the tensor \
along this dimension (Size={})",
dim,
tensor.shape().dims[dim]
)),
);
}
@ -129,15 +133,16 @@ impl TensorCheck {
if original.num_elements() != target.num_elements() {
check = check.register(
"Reshape",
TensorError::new(
"The given shape doesn't have the same number of elements as the current tensor.",
)
.details(format!(
"Current shape: {:?}, target shape: {:?}.",
original.dims, target.dims
)),
);
"Reshape",
TensorError::new(
"The given shape doesn't have the same number of elements as the current \
tensor.",
)
.details(format!(
"Current shape: {:?}, target shape: {:?}.",
original.dims, target.dims
)),
);
}
check
@ -202,8 +207,8 @@ impl TensorCheck {
check = check.register(
"Flatten",
TensorError::new(format!(
"The destination dimension ({D2}) must be large enough to accommodate the flattening operation."
"The destination dimension ({D2}) must be large enough to accommodate the \
flattening operation."
)),
);
}
@ -310,8 +315,9 @@ impl TensorCheck {
check = check.register(
"Matmul",
TensorError::new(format!(
"The inner dimension of matmul should be the same, but got {dim_lhs} and {dim_rhs}."
))
"The inner dimension of matmul should be the same, but got {dim_lhs} and \
{dim_rhs}."
))
.details(format!(
"Lhs shape {:?}, rhs shape {:?}.",
shape_lhs.dims, shape_rhs.dims
@ -402,16 +408,17 @@ impl TensorCheck {
if shape_reference != shape {
return check.register(
"Cat",
TensorError::new(
"Can't concatenate tensors with different shapes, except for the provided dimension",
)
.details(format!(
"Provided dimension ({}), tensors shapes: {:?}",
dim,
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
)),
);
"Cat",
TensorError::new(
"Can't concatenate tensors with different shapes, except for the provided \
dimension",
)
.details(format!(
"Provided dimension ({}), tensors shapes: {:?}",
dim,
tensors.iter().map(Tensor::shape).collect::<Vec<_>>()
)),
);
}
}
@ -427,13 +434,18 @@ impl TensorCheck {
let n_dims_ranges = D2;
if n_dims_tensor < n_dims_ranges {
check = check.register("Slice",
TensorError::new ("The provided ranges array has a higher number of dimensions than the current tensor.")
.details(
format!(
"The ranges array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {n_dims_tensor}, ranges array length {n_dims_ranges}."
)));
check = check.register(
"Slice",
TensorError::new(
"The provided ranges array has a higher number of dimensions than the current \
tensor.",
)
.details(format!(
"The ranges array must be smaller or equal to the tensor number of \
dimensions. Tensor number of dimensions: {n_dims_tensor}, ranges array \
length {n_dims_ranges}."
)),
);
}
for i in 0..usize::min(D1, D2) {
@ -442,31 +454,32 @@ impl TensorCheck {
if range.end > d_tensor {
check = check.register(
"Slice",
TensorError::new(
"The provided ranges array has a range that exceeds the current tensor size.",
)
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Tensor shape {:?}, provided ranges {:?}.",
range.start, range.end, d_tensor, i, shape.dims, ranges,
)),
);
"Slice",
TensorError::new(
"The provided ranges array has a range that exceeds the current tensor \
size.",
)
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Tensor shape {:?}, provided ranges {:?}.",
range.start, range.end, d_tensor, i, shape.dims, ranges,
)),
);
}
if range.start >= range.end {
check = check.register(
"Slice",
TensorError::new("The provided range array has a range where the start index is bigger or equal to its end.")
TensorError::new(
"The provided range array has a range where the start index is bigger or \
equal to its end.",
)
.details(format!(
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
Tensor shape {:?}, provided ranges {:?}.",
i,
range.start,
range.end,
shape.dims,
ranges,
)));
"The range at dimension '{}' starts at '{}' and is greater or equal to \
its end '{}'. Tensor shape {:?}, provided ranges {:?}.",
i, range.start, range.end, shape.dims, ranges,
)),
);
}
}
@ -482,15 +495,16 @@ impl TensorCheck {
if D1 < D2 {
check = check.register(
"Slice Assign",
TensorError::new(
"The provided ranges array has a higher number of dimensions than the current tensor.",
)
.details(format!(
"The ranges array must be smaller or equal to the tensor number of dimensions. \
Tensor number of dimensions: {D1}, ranges array length {D2}."
)),
);
"Slice Assign",
TensorError::new(
"The provided ranges array has a higher number of dimensions than the current \
tensor.",
)
.details(format!(
"The ranges array must be smaller or equal to the tensor number of \
dimensions. Tensor number of dimensions: {D1}, ranges array length {D2}."
)),
);
}
for i in 0..usize::min(D1, D2) {
@ -500,25 +514,30 @@ impl TensorCheck {
if range.end > d_tensor {
check = check.register(
"Range Assign",
TensorError::new(
"The provided ranges array has a range that exceeds the current tensor size.",
)
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
range.start, range.end, d_tensor, i, shape.dims, shape_value.dims, ranges,
)),
);
"Range Assign",
TensorError::new(
"The provided ranges array has a range that exceeds the current tensor \
size.",
)
.details(format!(
"The range ({}..{}) exceeds the size of the tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
range.start, range.end, d_tensor, i, shape.dims, shape_value.dims, ranges,
)),
);
}
if range.end - range.start != d_tensor_value {
check = check.register(
"Slice Assign",
TensorError::new("The value tensor must match the amount of elements selected with the ranges array")
TensorError::new(
"The value tensor must match the amount of elements selected with the \
ranges array",
)
.details(format!(
"The range ({}..{}) doesn't match the number of elements of the value tensor ({}) at dimension {}. \
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
"The range ({}..{}) doesn't match the number of elements of the value \
tensor ({}) at dimension {}. Current tensor shape {:?}, value tensor \
shape {:?}, provided ranges {:?}.",
range.start,
range.end,
d_tensor_value,
@ -526,23 +545,24 @@ impl TensorCheck {
shape.dims,
shape_value.dims,
ranges,
)));
)),
);
}
if range.start >= range.end {
check = check.register(
"Slice Assign",
TensorError::new("The provided ranges array has a range where the start index is bigger or equal to its end.")
TensorError::new(
"The provided ranges array has a range where the start index is bigger or \
equal to its end.",
)
.details(format!(
"The range at dimension '{}' starts at '{}' and is greater or equal to its end '{}'. \
Current tensor shape {:?}, value tensor shape {:?}, provided ranges {:?}.",
i,
range.start,
range.end,
shape.dims,
shape_value.dims,
ranges,
)));
"The range at dimension '{}' starts at '{}' and is greater or equal to \
its end '{}'. Current tensor shape {:?}, value tensor shape {:?}, \
provided ranges {:?}.",
i, range.start, range.end, shape.dims, shape_value.dims, ranges,
)),
);
}
}
@ -703,10 +723,10 @@ impl TensorCheck {
ops,
TensorError::new("The provided tensors have incompatible shapes.").details(
format!(
"Incompatible size at dimension '{}' => '{} != {}', which can't be broadcasted. \
Lhs tensor shape {:?}, Rhs tensor shape {:?}.",
i, d_lhs, d_rhs, lhs.dims, rhs.dims,
),
"Incompatible size at dimension '{}' => '{} != {}', which can't be \
broadcasted. Lhs tensor shape {:?}, Rhs tensor shape {:?}.",
i, d_lhs, d_rhs, lhs.dims, rhs.dims,
),
),
);
}


@ -320,9 +320,11 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
if err > tolerance {
// Only print the first 5 different values.
if num_diff < max_num_diff {
message +=
format!("\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}")
.as_str();
message += format!(
"\n => Position {i}: {a} != {b} | difference {err} > tolerance \
{tolerance}"
)
.as_str();
}
num_diff += 1;
}


@ -63,7 +63,16 @@ impl EarlyStoppingStrategy for MetricEarlyStoppingStrategy {
let should_stop = epoch - self.best_epoch >= n_epochs;
if should_stop {
log::info!("Stopping training loop, no improvement since epoch {}, {}: {}, current epoch {}, {}: {}", self.best_epoch, self.metric_name, self.best_value, epoch, self.metric_name, current_value);
log::info!(
"Stopping training loop, no improvement since epoch {}, {}: {}, current \
epoch {}, {}: {}",
self.best_epoch,
self.metric_name,
self.best_value,
epoch,
self.metric_name,
current_value
);
}
should_stop


@ -39,7 +39,8 @@ fn update_panic_hook(file_path: &str) {
std::panic::set_hook(Box::new(move |info| {
log::error!("PANIC => {}", info.to_string());
eprintln!(
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => '{file_path}'\n============="
"=== PANIC ===\nA fatal error happened, you can check the experiment logs here => \
'{file_path}'\n============="
);
hook(info);
}));


@ -145,23 +145,32 @@ impl TuiMetricsRenderer {
if let Event::Key(key) = event {
if let KeyCode::Char('q') = key.code {
self.popup = PopupState::Full(
"Quit".to_string(),
vec![
Callback::new(
"Stop the training.",
"Stop the training immediately. This will break from the training loop, but any remaining code after the loop will be executed.",
's',
QuitPopupAccept(self.interuptor.clone()),
),
Callback::new(
"Stop the training immediately.",
"Kill the program. This will create a panic! which will make the current training fails. Any code following the training won't be executed.",
'k',
KillPopupAccept,
),
Callback::new("Cancel", "Cancel the action, continue the training.", 'c', PopupCancel),
],
);
"Quit".to_string(),
vec![
Callback::new(
"Stop the training.",
"Stop the training immediately. This will break from the \
training loop, but any remaining code after the loop will be \
executed.",
's',
QuitPopupAccept(self.interuptor.clone()),
),
Callback::new(
"Stop the training immediately.",
"Kill the program. This will create a panic! which will make \
the current training fails. Any code following the training \
won't be executed.",
'k',
KillPopupAccept,
),
Callback::new(
"Cancel",
"Cancel the action, continue the training.",
'c',
PopupCancel,
),
],
);
}
}
}


@ -89,7 +89,10 @@ impl GraphicsApi for AutoGraphicsApi {
"opengl" => return wgpu::Backend::Gl,
"webgpu" => return wgpu::Backend::BrowserWebGpu,
_ => {
eprintln!("Invalid graphics backend specified in GRAPHICS_BACKEND environment variable");
eprintln!(
"Invalid graphics backend specified in GRAPHICS_BACKEND environment \
variable"
);
std::process::exit(1);
}
}


@ -14,9 +14,19 @@ pub fn launch<B: AutodiffBackend>(device: B::Device) {
"/tmp/text-classification-ag-news",
// Samples from the test dataset, but you are free to test with your own text.
vec![
"Jays power up to take finale Contrary to popular belief, the power never really snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it took some extra time for the batting orders to provide some extra wattage.".to_string(),
"Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one man to death and 14 others to prison terms for a series of attacks and terrorist plots in 2002, including the bombing of a French oil tanker.".to_string(),
"IBM puts grids to work at U.S. Open IBM will put a collection of its On Demand-related products and technologies to this test next week at the U.S. Open tennis championships, implementing a grid-based infrastructure capable of running multiple workloads including two not associated with the tournament.".to_string(),
"Jays power up to take finale Contrary to popular belief, the power never really \
snapped back at SkyDome on Sunday. The lights came on after an hour delay, but it \
took some extra time for the batting orders to provide some extra wattage."
.to_string(),
"Yemen Sentences 15 Militants on Terror Charges A court in Yemen has sentenced one \
man to death and 14 others to prison terms for a series of attacks and terrorist \
plots in 2002, including the bombing of a French oil tanker."
.to_string(),
"IBM puts grids to work at U.S. Open IBM will put a collection of its On \
Demand-related products and technologies to this test next week at the U.S. Open \
tennis championships, implementing a grid-based infrastructure capable of running \
multiple workloads including two not associated with the tournament."
.to_string(),
],
);
}


@ -15,8 +15,14 @@ pub fn launch<B: AutodiffBackend>(device: B::Device) {
// Samples from the test dataset, but you are free to test with your own text.
vec![
" Magnus Eriksson is a Swedish former footballer who played as a forward.".to_string(),
"Crossbeam Systems is headquartered in Boxborough Massachusetts and has offices in Europe Latin America and Asia Pacific. Crossbeam Systems was acquired by Blue Coat Systems in December 2012 and the Crossbeam brand has been fully absorbed into Blue Coat.".to_string(),
" Zia is the sequel to the award-winning Island of the Blue Dolphins by Scott O'Dell. It was published in 1976 sixteen years after the publication of the first novel.".to_string(),
"Crossbeam Systems is headquartered in Boxborough Massachusetts and has offices in \
Europe Latin America and Asia Pacific. Crossbeam Systems was acquired by Blue Coat \
Systems in December 2012 and the Crossbeam brand has been fully absorbed into Blue \
Coat."
.to_string(),
" Zia is the sequel to the award-winning Island of the Blue Dolphins by Scott O'Dell. \
It was published in 1976 sixteen years after the publication of the first novel."
.to_string(),
],
);
}


@ -72,6 +72,9 @@ pub fn infer<B: Backend, D: TextClassificationDataset + 'static>(
let class = D::class_name(class_index as usize); // Get class name
// Print sample text, predicted logits and predicted class
println!("\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: {class}\n================");
println!(
"\n=== Item {i} ===\n- Text: {text}\n- Logits: {logits}\n- Prediction: \
{class}\n================"
);
}
}

rustfmt.toml (new file, 5 lines)

@ -0,0 +1,5 @@
max_width = 100
# uncomment and run `cargo +nightly fmt --all` to find and fix lines that are too long (and therefore break autoformatting)
# error_on_line_overflow = true
# format_strings = true
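The comment above describes a one-off maintenance pass: temporarily enable the two unstable (nightly-only) options, run `cargo +nightly fmt --all`, fix any lines rustfmt reports as too long, then comment the options back out so the checked-in config stays minimal. A sketch of what the temporarily enabled file would presumably look like:

# rustfmt.toml with the unstable options switched on for a one-off pass
# (run with: cargo +nightly fmt --all)
max_width = 100
error_on_line_overflow = true  # error out on lines rustfmt cannot fit within max_width
format_strings = true          # let rustfmt break long string literals with `\`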