mirror of https://github.com/tracel-ai/burn.git
Typos (#608)
This commit is contained in:
parent 441a7011ce
commit 1d3bbaab13
@@ -0,0 +1,13 @@
+name: Typos
+on: pull_request
+
+jobs:
+  run:
+    name: Spell check with Typos
+    runs-on: ubuntu-20.04
+    steps:
+      - name: Checkout Actions Repository
+        uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
+
+      - name: Check spelling
+        uses: crate-ci/typos@8a7996b4bcfa526668e5a9e7914330428897e205
@@ -16,7 +16,7 @@ __Sections__
 Modules are a way of creating neural network structures that can be easily optimized, saved, and loaded with little to no boilerplate.
 Unlike other frameworks, a module does not force the declaration of the forward pass, leaving it up to the implementer to decide how it should be defined.
 Additionally, most modules are created using a (de)serializable configuration, which defines the structure of the module and its hyper-parameters.
-Parameters and hyper-parameters are not serialized into the same file and both are normaly necessary to load a module for inference.
+Parameters and hyper-parameters are not serialized into the same file and both are normally necessary to load a module for inference.
 
 ### Optimization
 
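The configuration pattern described in this hunk can be made concrete with a small sketch. It assumes only the `Config` derive and the `burn::config::config_to_json` helper that appear verbatim in the test hunks further down; `MyModuleConfig` and its fields are hypothetical:

```rust
use burn::config::Config;

// Hypothetical hyper-parameter set; only the structure matters here.
#[derive(Config)]
struct MyModuleConfig {
    d_model: usize,
    #[config(default = 0.1)]
    dropout: f64,
}

fn main() {
    // `new` takes the fields without defaults; hyper-parameters serialize
    // to JSON separately from the learned parameters.
    let config = MyModuleConfig::new(512);
    println!("{}", burn::config::config_to_json(&config));
}
```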
@@ -71,7 +71,7 @@ When performing an optimization step, the adaptor handles the following:
 3. Makes sure that the gradient, the tensor, and the optimizer state associated with the current parameter are on the same device.
    The device can be different if the state is loaded from disk to restart training.
 4. Performs the simple optimizer step using the inner tensor since the operations done by the optimizer should not be tracked in the autodiff graph.
-5. Updates the state for the current parameter and returns the updated tensor, making sure it's properly registered into the autodiff graph if gradients are maked as required.
+5. Updates the state for the current parameter and returns the updated tensor, making sure it's properly registered into the autodiff graph if gradients are marked as required.
 
 Note that a parameter can still be updated by another process, as is the case with running metrics used in batch norm.
 These tensors are still wrapped using the `Param` struct so that they are included in the module's state and given a proper parameter ID, but they are not registered in the autodiff graph.
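The numbered steps quoted above can be illustrated with a self-contained sketch. None of these names are Burn's actual API; the snippet only mirrors the flow of steps 3 to 5 with plain structs:

```rust
// Illustrative stand-ins, not Burn's types.
#[derive(Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Gpu(usize),
}

struct ParamTensor {
    values: Vec<f32>,
    device: Device,
    requires_grad: bool,
}

struct OptimizerState {
    momentum: Vec<f32>,
    device: Device,
}

fn step(param: &mut ParamTensor, grad: &[f32], state: &mut OptimizerState, lr: f32) {
    // Step 3: the state may have been restored from disk on another device.
    if state.device != param.device {
        state.device = param.device; // stands in for a real device transfer
    }
    // Step 4: plain optimizer math on the inner values, outside autodiff tracking.
    for ((w, g), m) in param.values.iter_mut().zip(grad).zip(state.momentum.iter_mut()) {
        *m = 0.9 * *m + *g;
        *w -= lr * *m;
    }
    // Step 5: here the updated tensor would be re-registered in the autodiff
    // graph when gradients are marked as required.
    let _ = param.requires_grad;
}
```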
@@ -295,7 +295,7 @@ Compile `scripts/publish.rs` using this command:
 rustc scripts/publish.rs --crate-type bin --out-dir scripts
 ```
 
-## Disclamer
+## Disclaimer
 
 Burn is currently in active development, and there will be breaking changes. While any resulting
 issues are likely to be easy to fix, there are no guarantees at this stage.
@@ -0,0 +1,8 @@
+[default]
+extend-ignore-identifiers-re = [
+    "NdArray*",
+    "ND"
+]
+
+[files]
+extend-exclude = ["assets/ModuleSerialization.xml"]
@@ -38,7 +38,7 @@ impl Graph {
     /// be shared with other graphs, therefore they are going to be cleared.
     ///
     /// This is usefull, since the graph is supposed to be consumed only once for backprop, and
-    /// keeping all the tensors alive for multiple backward call is a heavy waste of ressources.
+    /// keeping all the tensors alive for multiple backward call is a heavy waste of resources.
     pub fn steps(self) -> NodeSteps {
         let mut map_drain = HashMap::new();
         self.execute_mut(|map| {
@@ -11,7 +11,7 @@ use std::marker::PhantomData;
 
 /// Operation in preparation.
 ///
-/// There are 3 diffent modes: 'Init', 'Tracked' and 'UnTracked'.
+/// There are 3 different modes: 'Init', 'Tracked' and 'UnTracked'.
 /// Each mode has its own set of functions to minimize cloning for unused backward states.
 #[derive(new)]
 pub struct OpsPrep<Backward, B, S, const D: usize, const N: usize, Mode = Init> {
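The three modes in this doc comment follow the typestate pattern: each mode is a separate type parameter, so switching modes costs nothing at runtime and the compiler rules out calling a mode's functions from the wrong state. A minimal sketch of the idea (the real `OpsPrep` carries several more generics, as the struct definition above shows):

```rust
use std::marker::PhantomData;

// Zero-sized marker types standing in for the three modes.
struct Init;
struct Tracked;
struct UnTracked;

// Illustrative stand-in for the real struct.
struct OpsPrep<Mode> {
    _mode: PhantomData<Mode>,
}

impl OpsPrep<Init> {
    // Switching mode only changes a type parameter; no backward state
    // is cloned until a mode that actually needs it asks for it.
    fn tracked(self) -> OpsPrep<Tracked> {
        OpsPrep { _mode: PhantomData }
    }

    fn untracked(self) -> OpsPrep<UnTracked> {
        OpsPrep { _mode: PhantomData }
    }
}
```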
@@ -4,7 +4,7 @@ mod tests {
     use burn_tensor::{Data, Distribution, Int, Shape, Tensor};
 
     #[test]
-    fn should_handle_broacast_during_backward() {
+    fn should_handle_broadcast_during_backward() {
         let x: Tensor<TestADBackend, 2> = Tensor::from_data(
             Tensor::<TestADBackend, 1, Int>::arange(0..6)
                 .into_data()
@@ -137,7 +137,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
    /// # Note
    ///
    /// Don't use this function after an update on the same thread where other threads might have to
-   /// register their update before the actual synchonization needs to happen.
+   /// register their update before the actual synchronization needs to happen.
    pub fn value_sync(&self) -> Tensor<B, D> {
        let thread_id = get_thread_current_id();
        let mut map = self.values.lock().unwrap();
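The synchronization caveat becomes clearer with a sketch of the underlying idea: each thread registers its latest update in a shared map keyed by thread id, and a sync call drains the pending updates. The names and the averaging rule here are illustrative, not Burn's implementation:

```rust
use std::collections::HashMap;
use std::sync::Mutex;
use std::thread::{self, ThreadId};

struct RunningValue {
    values: Mutex<HashMap<ThreadId, f32>>,
}

impl RunningValue {
    // Each thread overwrites only its own slot.
    fn update(&self, value: f32) {
        self.values.lock().unwrap().insert(thread::current().id(), value);
    }

    // Averages and drains every pending per-thread update, which is why
    // updates registered on other threads after this call are missed by it.
    fn value_sync(&self) -> f32 {
        let mut map = self.values.lock().unwrap();
        let n = map.len().max(1) as f32;
        let sum: f32 = map.values().sum();
        map.clear();
        sum / n
    }
}
```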
@@ -35,7 +35,7 @@ pub struct GeneratePaddingMask<B: Backend> {
 pub fn generate_padding_mask<B: Backend>(
     pad_token: usize,
     tokens_list: Vec<Vec<usize>>,
-    max_seq_lenght: Option<usize>,
+    max_seq_length: Option<usize>,
     device: &B::Device,
 ) -> GeneratePaddingMask<B> {
     let mut max_size = 0;
@@ -46,9 +46,9 @@ pub fn generate_padding_mask<B: Backend>(
             max_size = tokens.len();
         }
 
-        if let Some(max_seq_lenght) = max_seq_lenght {
-            if tokens.len() >= max_seq_lenght {
-                max_size = max_seq_lenght;
+        if let Some(max_seq_length) = max_seq_length {
+            if tokens.len() >= max_seq_length {
+                max_size = max_seq_length;
                 break;
             }
         }
@@ -61,9 +61,9 @@ pub fn generate_padding_mask<B: Backend>(
         let mut seq_length = tokens.len();
         let mut tokens = tokens;
 
-        if let Some(max_seq_lenght) = max_seq_lenght {
-            if seq_length > max_seq_lenght {
-                seq_length = max_seq_lenght;
+        if let Some(max_seq_length) = max_seq_length {
+            if seq_length > max_seq_length {
+                seq_length = max_seq_length;
                 let _ = tokens.split_off(seq_length);
             }
         }
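Stripped of tensors and the mask itself, the two hunks above implement a cap-then-truncate-and-pad policy. A plain-Rust rendering of just that logic (a hypothetical helper, not the actual `generate_padding_mask`):

```rust
fn pad_tokens(
    pad_token: usize,
    mut tokens_list: Vec<Vec<usize>>,
    max_seq_length: Option<usize>,
) -> Vec<Vec<usize>> {
    // First pass: the longest sequence sets the batch width, capped by the limit.
    let mut max_size = tokens_list.iter().map(Vec::len).max().unwrap_or(0);
    if let Some(max_seq_length) = max_seq_length {
        max_size = max_size.min(max_seq_length);
    }
    // Second pass: truncate long sequences and pad short ones.
    for tokens in tokens_list.iter_mut() {
        tokens.truncate(max_size); // plays the role of `split_off` above
        tokens.resize(max_size, pad_token);
    }
    tokens_list
}
```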
@@ -392,7 +392,7 @@ mod tests {
         let output_1 = mha.forward(input_1);
         let output_2 = mha.forward(input_2);
 
-        // Check that the begginning of each tensor is the same
+        // Check that the beginning of each tensor is the same
         output_1
             .context
             .slice([0..batch_size, 0..seq_length - num_padded, 0..d_model])
@@ -150,7 +150,7 @@ impl<B: Backend> Lstm<B> {
         let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate);
         let add_values = activation::sigmoid(biased_ig_input_sum);
 
-        // o(utput)g(ate) tensors
+        // o(output)g(ate) tensors
         let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate);
         let output_values = activation::sigmoid(biased_og_input_sum);
 
@@ -34,7 +34,7 @@ pub struct JsonGzFileRecorder<S: PrecisionSettings> {
     _settings: PhantomData<S>,
 }
 
-/// File recorder using [pretty json format](serde_json) for easy redability.
+/// File recorder using [pretty json format](serde_json) for easy readability.
 #[derive(new, Debug, Default, Clone)]
 pub struct PrettyJsonFileRecorder<S: PrecisionSettings> {
     _settings: PhantomData<S>,
@@ -7,7 +7,7 @@ use serde::{de::DeserializeOwned, Serialize};
 ///
 /// # Notes
 ///
-/// This is especialy useful in no_std environment where weights are stored directly in
+/// This is especially useful in no_std environment where weights are stored directly in
 /// compiled binaries.
 pub trait BytesRecorder:
     Recorder<RecordArgs = (), RecordOutput = Vec<u8>, LoadArgs = Vec<u8>>
@@ -115,7 +115,7 @@ impl<B: Backend, const D: usize> Record for Param<Tensor<B, D>> {
     }
 }
 
-// Type that can be serialized as is without any convertion.
+// Type that can be serialized as is without any conversion.
 macro_rules! primitive {
     ($type:ty) => {
         impl Record for $type {
@@ -27,7 +27,7 @@ pub enum TestEnumConfig {
 #[cfg(feature = "std")]
 #[test]
 fn struct_config_should_impl_serde() {
-    let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
+    let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
     let file_path = "/tmp/test_struct_config.json";
 
     config.save(file_path).unwrap();
@@ -38,13 +38,13 @@ fn struct_config_should_impl_serde() {
 
 #[test]
 fn struct_config_should_impl_clone() {
-    let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
+    let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
     assert_eq!(config, config.clone());
 }
 
 #[test]
 fn struct_config_should_impl_display() {
-    let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
+    let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
     assert_eq!(burn::config::config_to_json(&config), config.to_string());
 }
 
@@ -75,7 +75,7 @@ fn enum_config_one_value_should_impl_serde() {
 #[cfg(feature = "std")]
 #[test]
 fn enum_config_multiple_values_should_impl_serde() {
-    let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
+    let config = TestEnumConfig::Multiple(42.0, "Allow".to_string());
     let file_path = "/tmp/test_enum_multiple_values_config.json";
 
     config.save(file_path).unwrap();
@@ -86,19 +86,19 @@ fn enum_config_multiple_values_should_impl_serde() {
 
 #[test]
 fn enum_config_should_impl_clone() {
-    let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
+    let config = TestEnumConfig::Multiple(42.0, "Allow".to_string());
     assert_eq!(config, config.clone());
 }
 
 #[test]
 fn enum_config_should_impl_display() {
-    let config = TestEnumConfig::Multiple(42.0, "Allo".to_string());
+    let config = TestEnumConfig::Multiple(42.0, "Allow".to_string());
     assert_eq!(burn::config::config_to_json(&config), config.to_string());
 }
 
 #[test]
 fn struct_config_can_load_binary() {
-    let config = TestStructConfig::new(2, 3.0, "Allo".to_string(), TestEmptyStructConfig::new());
+    let config = TestStructConfig::new(2, 3.0, "Allow".to_string(), TestEmptyStructConfig::new());
 
     let binary = config_to_json(&config).as_bytes().to_vec();
 
@@ -8,7 +8,7 @@ fn speech_command() {
     let item = test.get(index).unwrap();
 
     println!("Item: {:?}", item);
-    println!("Item Lengh: {:?}", item.audio_samples.len());
+    println!("Item Length: {:?}", item.audio_samples.len());
     println!("Label: {}", item.label.to_string());
 
     assert_eq!(test.len(), 4890);
@@ -576,7 +576,7 @@ where
     ///
     /// * A `Result` which is `Ok` if the table could be created, `Err` otherwise.
    ///
-    /// TODO (@antimora): add support creating a table with columns coresponding to the item fields
+    /// TODO (@antimora): add support creating a table with columns corresponding to the item fields
    fn create_table(&self, split: &str) -> Result<()> {
        // Check if the split already exists
        if self.splits.read().unwrap().contains(split) {
@@ -50,7 +50,7 @@ def download_and_export(name: str, subset: str, db_file: str, token: str, cache_
     dataset = dataset.flatten()
 
     # Rename columns to remove dots from the names
-    dataset = rename_colums(dataset)
+    dataset = rename_columns(dataset)
 
     print(f"Saving dataset: {name} - {key}")
     print(f"Dataset features: {dataset.features}")
@@ -81,7 +81,7 @@ def disable_decoding(dataset):
     return dataset
 
 
-def rename_colums(dataset):
+def rename_columns(dataset):
     """
     Rename columns to remove dots from the names. Dots appear in the column names because of the flattening.
     Dots are not allowed in column names in rust and sql (unless quoted). So we replace them with underscores.
@@ -85,7 +85,7 @@ pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
            }
        }
        syn::Data::Enum(_) => panic!("Only struct can be derived"),
-       syn::Data::Union(_) => panic!("Only struct cna be derived"),
+       syn::Data::Union(_) => panic!("Only struct can be derived"),
    };
    fields
 }
@@ -18,6 +18,6 @@ pub mod onnx;
 /// The module for generating the burn code.
 pub mod burn;
 
-mod formater;
+mod formatter;
 mod logger;
-pub use formater::*;
+pub use formatter::*;
@@ -204,7 +204,7 @@ message NodeProto {
   repeated string output = 2; // namespace Value
 
   // An optional identifier for this node in a graph.
-  // This field MAY be absent in ths version of the IR.
+  // This field MAY be absent in this version of the IR.
   string name = 3; // namespace Node
 
   // The symbolic identifier of the Operator to execute.
@@ -403,7 +403,7 @@ message ModelProto {
 //
 // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
 // In case of any conflicts the behavior (whether the model local functions are given higher priority,
-// or standard opserator sets are given higher priotity or this is treated as error) is defined by
+// or standard opserator sets are given higher priority or this is treated as error) is defined by
 // the runtimes.
 //
 // The operator sets imported by FunctionProto should be compatible with the ones
@@ -1,4 +1,4 @@
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
 
 use burn::{
     config::Config,
@@ -1,4 +1,4 @@
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
 
 use alloc::vec::Vec;
 
@@ -1,4 +1,4 @@
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
 
 use crate::{
     conv::{ConvBlock, ConvBlockConfig},
@@ -52,11 +52,11 @@ impl<B: Backend> Model<B> {
     }
 
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
-        let [batch_size, heigth, width] = input.dims();
+        let [batch_size, height, width] = input.dims();
 
-        let x = input.reshape([batch_size, 1, heigth, width]).detach();
+        let x = input.reshape([batch_size, 1, height, width]).detach();
         let x = self.conv.forward(x);
-        let x = x.reshape([batch_size, heigth * width]);
+        let x = x.reshape([batch_size, height * width]);
 
         let x = self.input.forward(x);
         let x = self.mlp.forward(x);
@@ -32,7 +32,7 @@ impl<E: tch::kind::Element, const D: usize> TchTensor<E, D> {
 
     /// Create a tensor that was created from an operation executed on a parent tensor.
     ///
-    /// If the child tensor shared the same storage as its parent, it will be cloned, effectivly
+    /// If the child tensor shared the same storage as its parent, it will be cloned, effectively
     /// tracking how much tensors point to the same memory space.
     pub fn from_existing(tensor: tch::Tensor, storage_parent: StorageRef) -> Self {
         let storage_child = tensor.data_ptr();
@@ -19,7 +19,7 @@ This library provides multiple tensor implementations hidden behind an easy to u
 
 ### Backends
 
-For now, only two backends are implementated, but adding new ones should not be that hard.
+For now, only two backends are implemented, but adding new ones should not be that hard.
 
 * [X] Pytorch using [tch-rs](https://github.com/LaurentMazare/tch-rs)
 * [X] 100% Rust backend using [ndarray](https://github.com/rust-ndarray/ndarray)
@@ -33,7 +33,7 @@ For now, only two backends are implementated, but adding new ones should not be
 Automatic differentiation is implemented as just another tensor backend without any global state.
 It's possible since we keep track of the order in which each operation as been executed and the tape is only created when calculating the gradients.
 To do so, each operation creates a new node which has a reference to its parent nodes.
-Therefore, creating the tape only requires a simple and efficent graph traversal algorithm.
+Therefore, creating the tape only requires a simple and efficient graph traversal algorithm.
 
 ```rust
 let x = ADTensor::from_tensor(x_ndarray);
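The traversal mentioned in this hunk can be sketched without any of Burn's types: nodes are created in execution order and remember their parents, so walking ids in reverse from the root visits only the ancestors of the output, each before its parents. A minimal illustration, assuming ids equal creation order (this is not Burn's actual data structure):

```rust
struct Node {
    parents: Vec<usize>,
}

// Collects the ancestors of `root` in reverse execution order, which is a
// valid order for accumulating gradients during the backward pass.
fn tape(nodes: &[Node], root: usize) -> Vec<usize> {
    let mut needed = vec![false; nodes.len()];
    needed[root] = true;
    let mut order = Vec::new();
    for id in (0..=root).rev() {
        if needed[id] {
            order.push(id);
            for &parent in &nodes[id].parents {
                needed[parent] = true;
            }
        }
    }
    order
}
```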
@@ -62,5 +62,5 @@ This crate can be used without the standard library (`#![no_std]`) with `alloc`
 the default `std` feature.
 
 * `std` - enables the standard library.
-* `burn-tensor-testgen` - enables test macros for genarating tensor tests.
+* `burn-tensor-testgen` - enables test macros for generating tensor tests.
 
@@ -278,7 +278,7 @@ impl TensorCheck {
             .details(
                 format!(
                     "The ranges array must be smaller or equal to the tensor number of dimensions. \
-                    Tensor number of dimensions: {n_dims_tensor}, ranges array lenght {n_dims_ranges}."
+                    Tensor number of dimensions: {n_dims_tensor}, ranges array length {n_dims_ranges}."
                 )));
         }
 
@@ -334,7 +334,7 @@ impl TensorCheck {
             .details(
                 format!(
                     "The ranges array must be smaller or equal to the tensor number of dimensions. \
-                    Tensor number of dimensions: {D1}, ranges array lenght {D2}."
+                    Tensor number of dimensions: {D1}, ranges array length {D2}."
                 )));
         }
 
@@ -510,7 +510,7 @@ impl TensorCheck {
     }
 
     /// The goal is to minimize the cost of checks when there are no error, but it's way less
-    /// important when an error occured, crafting a comprehensive error message is more important
+    /// important when an error occurred, crafting a comprehensive error message is more important
     /// than optimizing string manipulation.
     fn register(self, ops: &str, error: TensorError) -> Self {
         let errors = match self {
@@ -634,7 +634,7 @@ impl TensorError {
 }
 
 /// We use a macro for all checks, since the panic message file and line number will match the
-/// function that does the check instead of a the generic error.rs crate private unreleated file
+/// function that does the check instead of a the generic error.rs crate private unrelated file
 /// and line number.
 #[macro_export(local_inner_macros)]
 macro_rules! check {
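The rationale in that doc comment rests on a property of `macro_rules!`: a `panic!` that fires inside an expansion is reported at the invocation site, not inside the file defining the helper. A stripped-down illustration of the pattern (not Burn's actual `check!` macro, which takes a `TensorCheck`):

```rust
macro_rules! check {
    ($check:expr) => {
        if let Err(msg) = $check {
            // Reported with the file/line of the `check!` call site.
            core::panic!("{}", msg);
        }
    };
}

fn main() {
    let result: Result<(), String> = Ok(());
    check!(result); // a failing check would point at this very line
}
```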
@@ -250,7 +250,7 @@ where
 
     /// Detach the current tensor from the autodiff graph.
     /// This function does nothing when autodiff is not enabled.
-    /// This can be used in batchers or elsewere to ensure that previous operations are not
+    /// This can be used in batchers or elsewhere to ensure that previous operations are not
     /// considered in the autodiff graph.
     pub fn detach(self) -> Self {
         Self::new(B::detach(self.primitive))
@@ -35,7 +35,7 @@ where
         Self::new(K::add_scalar(self.primitive, other))
     }
 
-    /// Applies element wise substraction operation.
+    /// Applies element wise subtraction operation.
     ///
     /// `y = x2 - x1`
     #[allow(clippy::should_implement_trait)]
@@ -44,7 +44,7 @@ where
         Self::new(K::sub(self.primitive, other.primitive))
     }
 
-    /// Applies element wise substraction operation with a scalar.
+    /// Applies element wise subtraction operation with a scalar.
     ///
     /// `y = x - s`
     pub fn sub_scalar<E: ElementConversion>(self, other: E) -> Self {
@@ -238,7 +238,7 @@ where
     ///
     /// # Notes
     ///
-    /// The index tensor shoud have the same shape as the original tensor except for the dim
+    /// The index tensor should have the same shape as the original tensor except for the dim
     /// specified.
     pub fn gather(self, dim: usize, indices: Tensor<B, D, Int>) -> Self {
         check!(TensorCheck::gather::<D>(
@@ -250,7 +250,7 @@ where
         Self::new(K::gather(dim, self.primitive, indices))
     }
 
-    /// Assign the gathered elements corresponding to the given indices along the speficied dimension
+    /// Assign the gathered elements corresponding to the given indices along the specified dimension
     /// from the value tensor to the original tensor using sum reduction.
     ///
     /// Example using a 3D tensor:
@@ -261,7 +261,7 @@ where
     ///
     /// # Notes
     ///
-    /// The index tensor shoud have the same shape as the original tensor except for the speficied
+    /// The index tensor should have the same shape as the original tensor except for the specified
     /// dimension. The value and index tensors should have the same shape.
     ///
     /// Other references to the input tensor will not be modified by this operation.
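The shape rule in these gather/scatter doc comments is easiest to see on a tiny 2-D case. Written with nested `Vec`s instead of Burn tensors (a hypothetical helper, dim 1 only), the index matrix has the output's shape and each entry selects a column of the input:

```rust
fn gather_dim1(input: &[Vec<f32>], indices: &[Vec<usize>]) -> Vec<Vec<f32>> {
    indices
        .iter()
        .enumerate()
        .map(|(row, idx)| idx.iter().map(|&col| input[row][col]).collect())
        .collect()
}

fn main() {
    let input = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let indices = vec![vec![2, 0], vec![1, 1]];
    // Prints [[3.0, 1.0], [5.0, 5.0]]
    println!("{:?}", gather_dim1(&input, &indices));
}
```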
@@ -326,7 +326,7 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
         let tolerance = libm::pow(0.1, precision as f64);
 
         if err > tolerance {
-            // Only print the first 5 differents values.
+            // Only print the first 5 different values.
            if num_diff < max_num_diff {
                message += format!(
                    "\n => Position {i}: {a} != {b} | difference {err} > tolerance {tolerance}"
@@ -195,8 +195,8 @@ where
         LR::Record: 'static,
     {
         self.init_logger();
-        let callack = Box::new(self.dashboard);
-        let callback = Box::new(AsyncTrainerCallback::new(callack));
+        let callback = Box::new(self.dashboard);
+        let callback = Box::new(AsyncTrainerCallback::new(callback));
 
         let checkpointer_optimizer = match self.checkpointer_optimizer {
             Some(checkpointer) => {
@@ -54,7 +54,7 @@ pub trait Metric: Send + Sync {
 /// Adaptor are used to transform types so that they can be used by metrics.
 ///
 /// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
-/// registed with the [leaner buidler](crate::learner::LearnerBuilder) .
+/// registered with the [leaner buidler](crate::learner::LearnerBuilder) .
 pub trait Adaptor<T> {
     /// Adapt the type to be passed to a [metric](Metric).
     fn adapt(&self) -> T;
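The trait quoted in this hunk is small enough to restate with a toy model output; `ClassificationOutput` and `LossInput` here are hypothetical stand-ins for whatever output and metric-input types a model defines:

```rust
pub trait Adaptor<T> {
    fn adapt(&self) -> T;
}

struct ClassificationOutput {
    loss: f32,
}

struct LossInput {
    loss: f32,
}

// The model's output adapts itself into the input each metric expects,
// so one output type can feed several metrics.
impl Adaptor<LossInput> for ClassificationOutput {
    fn adapt(&self) -> LossInput {
        LossInput { loss: self.loss }
    }
}
```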
@@ -21,7 +21,7 @@ pub struct TrainingProgress {
 }
 
 impl TrainingProgress {
-    /// Creates a new empy training progress.
+    /// Creates a new empty training progress.
     pub fn none() -> Self {
         Self {
             progress: Progress {
@@ -83,9 +83,9 @@ impl Context {
     /// # Notes
     ///
     /// This function isn't safe, buffer can be mutated by the GPU. The users must ensure that a
-    /// buffer can be mutated when lauching a compute shaders with write access to a buffer.
+    /// buffer can be mutated when launching a compute shaders with write access to a buffer.
     ///
-    /// Buffer positions are used as bindings when lauching a compute kernel.
+    /// Buffer positions are used as bindings when launching a compute kernel.
     pub fn execute(
         &self,
         work_group: WorkGroup,
@@ -22,7 +22,7 @@ pub trait ContextServer {
     fn start(device: Arc<wgpu::Device>, queue: wgpu::Queue) -> Self::Client;
 }
 
-/// Context server where each operation is added in a synchonous maner.
+/// Context server where each operation is added in a synchronous maner.
 #[derive(Debug)]
 pub struct SyncContextServer {
     device: Arc<wgpu::Device>,
@@ -141,7 +141,7 @@ impl SyncContextServer {
     fn submit(&mut self) {
         assert!(
             self.tasks.is_empty(),
-            "Tasks should be completed before submiting the current encoder."
+            "Tasks should be completed before submitting the current encoder."
         );
         let mut new_encoder = self
             .device
@@ -10,7 +10,7 @@ use crate::{
 
 use super::base::empty_from_context;
 
-// Output of the pad_round function. Allows to know explicitly if early return occured
+// Output of the pad_round function. Allows to know explicitly if early return occurred
 pub(super) enum PaddingOutput<E: WgpuElement, const D: usize> {
     Padded(WgpuTensor<E, D>),
     Unchanged(WgpuTensor<E, D>),
@@ -3,7 +3,7 @@ use burn_tensor::Shape;
 use std::sync::Arc;
 use wgpu::Buffer;
 
-/// Build basic info to lauch pool 2d kernels.
+/// Build basic info to launch pool 2d kernels.
 pub fn build_output_and_info_pool2d<E: WgpuElement>(
     x: &WgpuTensor<E, 4>,
     kernel_size: [usize; 2],
@@ -101,7 +101,7 @@
       isDrawingMode: true,
     });
 
-    const backgroundColor = "rgba(255, 255, 255, 255)"; // White with solid alha
+    const backgroundColor = "rgba(255, 255, 255, 255)"; // White with solid alpha
     fabricCanvas.freeDrawingBrush.width = 25;
     fabricCanvas.backgroundColor = backgroundColor;
 
@@ -1,5 +1,5 @@
 
-# Openning index.html file directly by a browser does not work because of
+# Opening index.html file directly by a browser does not work because of
 # the security restrictions by the browser. Viewing the HTML file will fail with
 # this error message:
 
@@ -1,6 +1,6 @@
 #![allow(clippy::new_without_default)]
 
-// Orginally copied from the burn/examples/mnist package
+// Originally copied from the burn/examples/mnist package
 
 use burn::{
     module::Module,
@@ -48,15 +48,15 @@ impl<B: Backend> Model<B> {
     }
 
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
-        let [batch_size, heigth, width] = input.dims();
+        let [batch_size, height, width] = input.dims();
 
-        let x = input.reshape([batch_size, 1, heigth, width]).detach();
+        let x = input.reshape([batch_size, 1, height, width]).detach();
         let x = self.conv1.forward(x);
         let x = self.conv2.forward(x);
         let x = self.conv3.forward(x);
 
-        let [batch_size, channels, heigth, width] = x.dims();
-        let x = x.reshape([batch_size, channels * heigth * width]);
+        let [batch_size, channels, height, width] = x.dims();
+        let x = x.reshape([batch_size, channels * height * width]);
 
         let x = self.dropout.forward(x);
         let x = self.fc1.forward(x);
@@ -50,15 +50,15 @@ impl<B: Backend> Model<B> {
     }
 
     pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
-        let [batch_size, heigth, width] = input.dims();
+        let [batch_size, height, width] = input.dims();
 
-        let x = input.reshape([batch_size, 1, heigth, width]).detach();
+        let x = input.reshape([batch_size, 1, height, width]).detach();
         let x = self.conv1.forward(x);
         let x = self.conv2.forward(x);
         let x = self.conv3.forward(x);
 
-        let [batch_size, channels, heigth, width] = x.dims();
-        let x = x.reshape([batch_size, channels * heigth * width]);
+        let [batch_size, channels, height, width] = x.dims();
+        let x = x.reshape([batch_size, channels * height * width]);
 
         let x = self.dropout.forward(x);
         let x = self.fc1.forward(x);
@@ -24,7 +24,7 @@ pub fn run<B: Backend>() {
     //
     // mismatched types
     // expected reference `&NamedTensor<B, (Batch, DModel, _)>`
-    // found reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
+    // found reference `&NamedTensor<B, (Batch, SeqLength, DModel)>`
    // let output = weights.matmul(&input);
 
    let output = input.clone().matmul(weights.clone());
@@ -32,7 +32,7 @@ pub fn run<B: Backend>() {
     // Doesn't compile
     //
     // mismatched types
-    // expected reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
+    // expected reference `&NamedTensor<B, (Batch, SeqLength, DModel)>`
     // found reference `&NamedTensor<B, (Batch, DModel, DModel)>`
     // let output = output.mul(&weights);
 
@@ -23,7 +23,7 @@ use std::sync::Arc;
 pub struct TextClassificationBatcher<B: Backend> {
     tokenizer: Arc<dyn Tokenizer>, // Tokenizer for converting text to token IDs
     device: B::Device, // Device on which to perform computation (e.g., CPU or CUDA device)
-    max_seq_lenght: usize, // Maximum sequence length for tokenized text
+    max_seq_length: usize, // Maximum sequence length for tokenized text
 }
 
 /// Struct for training batch in text classification task
@@ -60,7 +60,7 @@ impl<B: Backend> Batcher<TextClassificationItem, TextClassificationTrainingBatch
         let mask = generate_padding_mask(
             self.tokenizer.pad_token(),
             tokens_list,
-            Some(self.max_seq_lenght),
+            Some(self.max_seq_length),
             &B::Device::default(),
         );
 
@@ -90,7 +90,7 @@ impl<B: Backend> Batcher<String, TextClassificationInferenceBatch<B>>
         let mask = generate_padding_mask(
             self.tokenizer.pad_token(),
             tokens_list,
-            Some(self.max_seq_lenght),
+            Some(self.max_seq_length),
             &B::Device::default(),
         );
 
@@ -9,7 +9,7 @@ use std::sync::Arc;
 #[derive(new)]
 pub struct TextGenerationBatcher {
     tokenizer: Arc<dyn Tokenizer>,
-    max_seq_lenght: usize,
+    max_seq_length: usize,
 }
 
 #[derive(Debug, Clone, new)]
@@ -36,7 +36,7 @@ impl<B: Backend> Batcher<TextGenerationItem, TextGenerationBatch<B>> for TextGen
         let mask = generate_padding_mask(
             self.tokenizer.pad_token(),
             tokens_list,
-            Some(self.max_seq_lenght),
+            Some(self.max_seq_length),
             &B::Device::default(),
         );
 