mirror of https://github.com/tracel-ai/burn.git
Doc fixes (#418)
This commit is contained in:
parent
73a88d8209
commit
fce45f51be
|
@ -3,7 +3,7 @@ use crate as burn;
|
||||||
use super::LRScheduler;
|
use super::LRScheduler;
|
||||||
use crate::{config::Config, LearningRate};
|
use crate::{config::Config, LearningRate};
|
||||||
|
|
||||||
/// Configuration to create a [noam](NoamScheduler) learning rate scheduler.
|
/// Configuration to create a [noam](NoamLRScheduler) learning rate scheduler.
|
||||||
#[derive(Config)]
|
#[derive(Config)]
|
||||||
pub struct NoamLRSchedulerConfig {
|
pub struct NoamLRSchedulerConfig {
|
||||||
/// The initial learning rate.
|
/// The initial learning rate.
|
||||||
|
@ -26,7 +26,7 @@ pub struct NoamLRScheduler {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NoamLRSchedulerConfig {
|
impl NoamLRSchedulerConfig {
|
||||||
/// Initialize a new [noam](NoamScheduler) learning rate scheduler.
|
/// Initialize a new [noam](NoamLRScheduler) learning rate scheduler.
|
||||||
pub fn init(&self) -> NoamLRScheduler {
|
pub fn init(&self) -> NoamLRScheduler {
|
||||||
NoamLRScheduler {
|
NoamLRScheduler {
|
||||||
warmup_steps: self.warmup_steps as f64,
|
warmup_steps: self.warmup_steps as f64,
|
||||||
|
|
|
@ -64,7 +64,7 @@ macro_rules! module {
|
||||||
///
|
///
|
||||||
/// Modules should be created using the [derive](burn_derive::Module) attribute.
|
/// Modules should be created using the [derive](burn_derive::Module) attribute.
|
||||||
/// This will make your module trainable, savable and loadable via
|
/// This will make your module trainable, savable and loadable via
|
||||||
/// [state](Module::state) and [load](Module::load).
|
/// `state` and `load`.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
|
|
|
@ -59,7 +59,7 @@ pub enum Conv1dPaddingConfig {
|
||||||
/// - weight: Tensor of shape [channels_out, channels_in, kernel_size] initialized from a uniform
|
/// - weight: Tensor of shape [channels_out, channels_in, kernel_size] initialized from a uniform
|
||||||
/// distribution `U(-k, k)` where `k = sqrt(1 / channels_in * kernel_size)`
|
/// distribution `U(-k, k)` where `k = sqrt(1 / channels_in * kernel_size)`
|
||||||
///
|
///
|
||||||
/// - bias: Tensor of shape [channels_out], initialized from a uniform distribution `U(-k, k)`
|
/// - bias: Tensor of shape `[channels_out]`, initialized from a uniform distribution `U(-k, k)`
|
||||||
/// where `k = sqrt(1 / channels_in * kernel_size)`
|
/// where `k = sqrt(1 / channels_in * kernel_size)`
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct Conv1d<B: Backend> {
|
pub struct Conv1d<B: Backend> {
|
||||||
|
|
|
@ -54,10 +54,10 @@ pub enum Conv2dPaddingConfig {
|
||||||
///
|
///
|
||||||
/// # Params
|
/// # Params
|
||||||
///
|
///
|
||||||
/// - weight: Tensor of shape [channels_out, channels_in, kernel_size_1, kernel_size_2] initialized from a uniform
|
/// - weight: Tensor of shape `[channels_out, channels_in, kernel_size_1, kernel_size_2]` initialized from a uniform
|
||||||
/// distribution `U(-k, k)` where `k = sqrt(1 / channels_in * kernel_size_1 * kernel_size_2)`
|
/// distribution `U(-k, k)` where `k = sqrt(1 / channels_in * kernel_size_1 * kernel_size_2)`
|
||||||
///
|
///
|
||||||
/// - bias: Tensor of shape [channels_out], initialized from a uniform distribution `U(-k, k)`
|
/// - bias: Tensor of shape `[channels_out]`, initialized from a uniform distribution `U(-k, k)`
|
||||||
/// where `k = sqrt(1 / channels_in * kernel_size_1 * kernel_size_2)`
|
/// where `k = sqrt(1 / channels_in * kernel_size_1 * kernel_size_2)`
|
||||||
#[derive(Module, Debug)]
|
#[derive(Module, Debug)]
|
||||||
pub struct Conv2d<B: Backend> {
|
pub struct Conv2d<B: Backend> {
|
||||||
|
|
|
@ -50,8 +50,8 @@ impl Initializer {
|
||||||
/// # Params
|
/// # Params
|
||||||
///
|
///
|
||||||
/// - shape: Shape of the initiated tensor.
|
/// - shape: Shape of the initiated tensor.
|
||||||
/// - fan_in: Option<usize>, the fan in to use in initialization formula, if needed
|
/// - fan_in: `Option<usize>`, the fan in to use in initialization formula, if needed
|
||||||
/// - fan_out: Option<usize>, the fan out to use in initialization formula, if needed
|
/// - fan_out: `Option<usize>`, the fan out to use in initialization formula, if needed
|
||||||
pub fn init_with<B: Backend, const D: usize, S: Into<Shape<D>>>(
|
pub fn init_with<B: Backend, const D: usize, S: Into<Shape<D>>>(
|
||||||
&self,
|
&self,
|
||||||
shape: S,
|
shape: S,
|
||||||
|
|
|
@ -22,8 +22,8 @@ impl<B: Backend> CrossEntropyLoss<B> {
|
||||||
///
|
///
|
||||||
/// # Shapes
|
/// # Shapes
|
||||||
///
|
///
|
||||||
/// - logits: [batch_size, num_targets]
|
/// - logits: `[batch_size, num_targets]`
|
||||||
/// - targets: [batch_size]
|
/// - targets: `[batch_size]`
|
||||||
pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
|
pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
|
||||||
let [batch_size] = targets.dims();
|
let [batch_size] = targets.dims();
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ impl LstmConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize a new [lstm](lstm) module with a [record](LstmRecord).
|
/// Initialize a new [lstm](Lstm) module with a [record](LstmRecord).
|
||||||
pub fn init_with<B: Backend>(&self, record: LstmRecord<B>) -> Lstm<B> {
|
pub fn init_with<B: Backend>(&self, record: LstmRecord<B>) -> Lstm<B> {
|
||||||
let linear_config = LinearConfig {
|
let linear_config = LinearConfig {
|
||||||
d_input: self.d_input,
|
d_input: self.d_input,
|
||||||
|
|
|
@ -24,7 +24,7 @@ pub struct SgdConfig {
|
||||||
|
|
||||||
/// Optimizer that implements stochastic gradient descent with momentum.
|
/// Optimizer that implements stochastic gradient descent with momentum.
|
||||||
///
|
///
|
||||||
/// Momentum is optional and can be [configured](SgdConfig::momentum).
|
/// The optimizer can be configured with [SgdConfig](SgdConfig).
|
||||||
pub struct Sgd<B: Backend> {
|
pub struct Sgd<B: Backend> {
|
||||||
momentum: Option<Momentum<B>>,
|
momentum: Option<Momentum<B>>,
|
||||||
weight_decay: Option<WeightDecay<B>>,
|
weight_decay: Option<WeightDecay<B>>,
|
||||||
|
|
|
@ -28,6 +28,6 @@ where
|
||||||
/// Change the device of the state.
|
/// Change the device of the state.
|
||||||
///
|
///
|
||||||
/// This function will be called accordindly to have the state on the same device as the
|
/// This function will be called accordindly to have the state on the same device as the
|
||||||
/// gradient and the tensor when the [step](SimpleModuleOptimizer::step) function is called.
|
/// gradient and the tensor when the [step](SimpleOptimizer::step) function is called.
|
||||||
fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
|
fn to_device<const D: usize>(state: Self::State<D>, device: &B::Device) -> Self::State<D>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,11 +3,11 @@ pub use burn_derive::Record;
|
||||||
use super::PrecisionSettings;
|
use super::PrecisionSettings;
|
||||||
use serde::{de::DeserializeOwned, Serialize};
|
use serde::{de::DeserializeOwned, Serialize};
|
||||||
|
|
||||||
/// Trait to define a family of types which can be recorded using any [settings](RecordSettings).
|
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
|
||||||
pub trait Record: Send + Sync {
|
pub trait Record: Send + Sync {
|
||||||
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
|
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
|
||||||
|
|
||||||
/// Convert the current record into the corresponding item that follows the given [settings](RecordSettings).
|
/// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
|
||||||
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
|
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
|
||||||
/// Convert the given item into a record.
|
/// Convert the given item into a record.
|
||||||
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
|
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
|
||||||
|
|
|
@ -21,7 +21,7 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
|
||||||
/// Arguments used to load recorded objects.
|
/// Arguments used to load recorded objects.
|
||||||
type LoadArgs: Clone;
|
type LoadArgs: Clone;
|
||||||
|
|
||||||
/// Record using the given [settings](RecordSettings).
|
/// Record an item with the given arguments.
|
||||||
fn record<R: Record>(
|
fn record<R: Record>(
|
||||||
&self,
|
&self,
|
||||||
record: R,
|
record: R,
|
||||||
|
|
|
@ -43,7 +43,7 @@ where
|
||||||
|
|
||||||
/// Create from a json rows file (one json per line).
|
/// Create from a json rows file (one json per line).
|
||||||
///
|
///
|
||||||
/// Supported field types: https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html
|
/// [Supported field types](https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html)
|
||||||
pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
|
pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
|
||||||
let file = File::open(path)?;
|
let file = File::open(path)?;
|
||||||
let reader = BufReader::new(file);
|
let reader = BufReader::new(file);
|
||||||
|
@ -65,7 +65,7 @@ where
|
||||||
///
|
///
|
||||||
/// The supported field types are: String, integer, float, and bool.
|
/// The supported field types are: String, integer, float, and bool.
|
||||||
///
|
///
|
||||||
/// See: https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde
|
/// See: [Reading with Serde](https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde)
|
||||||
pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
|
pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
|
||||||
let file = File::open(path)?;
|
let file = File::open(path)?;
|
||||||
let reader = BufReader::new(file);
|
let reader = BufReader::new(file);
|
||||||
|
|
|
@ -68,13 +68,13 @@ impl From<&'static str> for SqliteDatasetError {
|
||||||
/// can be in any order.
|
/// can be in any order.
|
||||||
///
|
///
|
||||||
/// For the supported field types, refer to:
|
/// For the supported field types, refer to:
|
||||||
/// - Serialization field types: https://docs.rs/serde_rusqlite/latest/serde_rusqlite
|
/// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
|
||||||
/// - SQLite data types: https://www.sqlite.org/datatype3.html
|
/// - [SQLite data types](https://www.sqlite.org/datatype3.html)
|
||||||
///
|
///
|
||||||
/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
|
/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
|
||||||
/// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
|
/// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
|
||||||
/// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
|
/// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
|
||||||
/// MessagePack (https://msgpack.org/).
|
/// [MessagePack](https://msgpack.org/).
|
||||||
///
|
///
|
||||||
/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
|
/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
|
||||||
/// method to read the data from the table.
|
/// method to read the data from the table.
|
||||||
|
@ -490,7 +490,7 @@ where
|
||||||
|
|
||||||
/// Serializes and writes an item to the database. The item is written to the table for the
|
/// Serializes and writes an item to the database. The item is written to the table for the
|
||||||
/// specified split. If the table does not exist, it is created. If the table exists, the item
|
/// specified split. If the table does not exist, it is created. If the table exists, the item
|
||||||
/// is appended to the table. The serialization is done using the MessagePack (https://msgpack.org/)
|
/// is appended to the table. The serialization is done using the [MessagePack](https://msgpack.org/)
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
|
|
|
@ -81,7 +81,7 @@ impl HuggingfaceDatasetLoader {
|
||||||
|
|
||||||
/// Specify a huggingface token to download datasets behind authentication.
|
/// Specify a huggingface token to download datasets behind authentication.
|
||||||
///
|
///
|
||||||
/// You can get a token from https://huggingface.co/settings/tokens
|
/// You can get a token from [tokens settings](https://huggingface.co/settings/tokens)
|
||||||
pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
|
pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
|
||||||
self.huggingface_token = Some(huggingface_token.to_string());
|
self.huggingface_token = Some(huggingface_token.to_string());
|
||||||
self
|
self
|
||||||
|
|
|
@ -81,9 +81,9 @@ pub trait ModuleOps<B: Backend> {
|
||||||
///
|
///
|
||||||
/// # Shapes
|
/// # Shapes
|
||||||
///
|
///
|
||||||
/// x: [batch_size, channels_in, height, width],
|
/// x: `[batch_size, channels_in, height, width]`,
|
||||||
/// weight: [channels_out, channels_in, kernel_size_1, kernel_size_2],
|
/// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
|
||||||
/// bias: [channels_out],
|
/// bias: `[channels_out]`,
|
||||||
fn conv2d(
|
fn conv2d(
|
||||||
x: B::TensorPrimitive<4>,
|
x: B::TensorPrimitive<4>,
|
||||||
weight: B::TensorPrimitive<4>,
|
weight: B::TensorPrimitive<4>,
|
||||||
|
@ -94,9 +94,9 @@ pub trait ModuleOps<B: Backend> {
|
||||||
///
|
///
|
||||||
/// # Shapes
|
/// # Shapes
|
||||||
///
|
///
|
||||||
/// x: [batch_size, channels_in, height, width],
|
/// x: `[batch_size, channels_in, height, width]`,
|
||||||
/// weight: [channels_in, channels_out, kernel_size_1, kernel_size_2],
|
/// weight: `[channels_in, channels_out, kernel_size_1, kernel_size_2]`,
|
||||||
/// bias: [channels_out],
|
/// bias: `[channels_out]`,
|
||||||
fn conv_transpose2d(
|
fn conv_transpose2d(
|
||||||
x: B::TensorPrimitive<4>,
|
x: B::TensorPrimitive<4>,
|
||||||
weight: B::TensorPrimitive<4>,
|
weight: B::TensorPrimitive<4>,
|
||||||
|
@ -118,9 +118,9 @@ pub trait ModuleOps<B: Backend> {
|
||||||
///
|
///
|
||||||
/// # Shapes
|
/// # Shapes
|
||||||
///
|
///
|
||||||
/// x: [batch_size, channels_in, length],
|
/// x: `[batch_size, channels_in, length]`,
|
||||||
/// weight: [channels_out, channels_in, kernel_size],
|
/// weight: `[channels_out, channels_in, kernel_size]`,
|
||||||
/// bias: [channels_out],
|
/// bias: `[channels_out]`,
|
||||||
fn conv1d(
|
fn conv1d(
|
||||||
x: B::TensorPrimitive<3>,
|
x: B::TensorPrimitive<3>,
|
||||||
weight: B::TensorPrimitive<3>,
|
weight: B::TensorPrimitive<3>,
|
||||||
|
@ -133,9 +133,9 @@ pub trait ModuleOps<B: Backend> {
|
||||||
///
|
///
|
||||||
/// # Shapes
|
/// # Shapes
|
||||||
///
|
///
|
||||||
/// x: [batch_size, channels_in, length],
|
/// x: `[batch_size, channels_in, length]`,
|
||||||
/// weight: [channels_in, channels_out, length],
|
/// weight: `[channels_in, channels_out, length]`,
|
||||||
/// bias: [channels_out],
|
/// bias: `[channels_out]`,
|
||||||
fn conv_transpose1d(
|
fn conv_transpose1d(
|
||||||
x: B::TensorPrimitive<3>,
|
x: B::TensorPrimitive<3>,
|
||||||
weight: B::TensorPrimitive<3>,
|
weight: B::TensorPrimitive<3>,
|
||||||
|
|
|
@ -7,7 +7,7 @@ use burn_core::tensor::backend::ADBackend;
|
||||||
|
|
||||||
/// Learner struct encapsulating all components necessary to train a Neural Network model.
|
/// Learner struct encapsulating all components necessary to train a Neural Network model.
|
||||||
///
|
///
|
||||||
/// To create a learner, use the [builder](crate::train::LearnerBuilder) struct.
|
/// To create a learner, use the [builder](crate::learner::LearnerBuilder) struct.
|
||||||
pub struct Learner<B, M, O, LR, TO, VO>
|
pub struct Learner<B, M, O, LR, TO, VO>
|
||||||
where
|
where
|
||||||
B: ADBackend,
|
B: ADBackend,
|
||||||
|
|
|
@ -144,8 +144,8 @@ where
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a checkpointer that will save the [optimizer](crate::optim::Optimizer) and the
|
/// Register a checkpointer that will save the [optimizer](Optimizer) and the
|
||||||
/// [model](crate::module::Module) [states](crate::module::State).
|
/// [model](ADModule).
|
||||||
///
|
///
|
||||||
/// The number of checkpoints to be keep should be set to a minimum of two to be safe, since
|
/// The number of checkpoints to be keep should be set to a minimum of two to be safe, since
|
||||||
/// they are saved and deleted asynchronously and a crash during training might make a
|
/// they are saved and deleted asynchronously and a crash during training might make a
|
||||||
|
|
|
@ -44,7 +44,7 @@ pub trait Metric: Send + Sync {
|
||||||
/// Adaptor are used to transform types so that they can be used by metrics.
|
/// Adaptor are used to transform types so that they can be used by metrics.
|
||||||
///
|
///
|
||||||
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
|
/// This should be implemented by a model's output type for all [metric inputs](Metric::Input) that are
|
||||||
/// registed with the [leaner buidler](burn::train::LearnerBuilder).
|
/// registed with the [leaner buidler](crate::learner::LearnerBuilder) .
|
||||||
pub trait Adaptor<T> {
|
pub trait Adaptor<T> {
|
||||||
/// Adapt the type to be passed to a [metric](Metric).
|
/// Adapt the type to be passed to a [metric](Metric).
|
||||||
fn adapt(&self) -> T;
|
fn adapt(&self) -> T;
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use super::{MetricEntry, Numeric};
|
use super::{MetricEntry, Numeric};
|
||||||
|
|
||||||
/// Usefull utility to implement numeric [metrics](crate::train::metric::Metric).
|
/// Usefull utility to implement numeric metrics.
|
||||||
///
|
///
|
||||||
/// # Notes
|
/// # Notes
|
||||||
///
|
///
|
||||||
|
|
|
@ -3,24 +3,29 @@
|
||||||
/// Options are:
|
/// Options are:
|
||||||
/// - [Vulkan](Vulkan)
|
/// - [Vulkan](Vulkan)
|
||||||
/// - [Metal](Metal)
|
/// - [Metal](Metal)
|
||||||
/// - [OpenGL](OpenGL)
|
/// - [OpenGL](OpenGl)
|
||||||
/// - [DirectX 11](Dx11)
|
/// - [DirectX 11](Dx11)
|
||||||
/// - [DirectX 12](Dx12)
|
/// - [DirectX 12](Dx12)
|
||||||
/// - [WebGPU](WebGPU)
|
/// - [WebGpu](WebGpu)
|
||||||
pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
|
pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static {
|
||||||
fn backend() -> wgpu::Backend;
|
fn backend() -> wgpu::Backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Vulkan;
|
pub struct Vulkan;
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Metal;
|
pub struct Metal;
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct OpenGl;
|
pub struct OpenGl;
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Dx11;
|
pub struct Dx11;
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct Dx12;
|
pub struct Dx12;
|
||||||
|
|
||||||
#[derive(Default, Debug, Clone)]
|
#[derive(Default, Debug, Clone)]
|
||||||
pub struct WebGpu;
|
pub struct WebGpu;
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ use burn::tensor::Tensor;
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
|
|
||||||
/// Mnist structure that corresponds to JavaScript class.
|
/// Mnist structure that corresponds to JavaScript class.
|
||||||
/// See: https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html
|
/// See:[exporting-rust-struct](https://rustwasm.github.io/wasm-bindgen/contributing/design/exporting-rust-struct.html)
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub struct Mnist {
|
pub struct Mnist {
|
||||||
model: Model<Backend>,
|
model: Model<Backend>,
|
||||||
|
@ -35,8 +35,8 @@ impl Mnist {
|
||||||
/// * `input` - A f32 slice of input 28x28 image
|
/// * `input` - A f32 slice of input 28x28 image
|
||||||
///
|
///
|
||||||
/// See bindgen support types for passing and returning arrays:
|
/// See bindgen support types for passing and returning arrays:
|
||||||
/// * https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html
|
/// * [number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/number-slices.html)
|
||||||
/// * https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html
|
/// * [boxed-number-slices](https://rustwasm.github.io/wasm-bindgen/reference/types/boxed-number-slices.html)
|
||||||
///
|
///
|
||||||
pub fn inference(&self, input: &[f32]) -> Result<Box<[f32]>, String> {
|
pub fn inference(&self, input: &[f32]) -> Result<Box<[f32]>, String> {
|
||||||
// Reshape from the 1D array to 3d tensor [batch, height, width]
|
// Reshape from the 1D array to 3d tensor [batch, height, width]
|
||||||
|
|
|
@ -52,12 +52,18 @@ build_and_test_all_features() {
|
||||||
echo "Build with all defaults"
|
echo "Build with all defaults"
|
||||||
cargo build --all-features
|
cargo build --all-features
|
||||||
|
|
||||||
echo "Test with defaults"
|
echo "Test with all features"
|
||||||
cargo test --all-features
|
cargo test --all-features
|
||||||
|
|
||||||
|
echo "Check documentation with all features"
|
||||||
|
cargo doc --all-features
|
||||||
|
|
||||||
cd .. || exit
|
cd .. || exit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Set RUSTDOCFLAGS to treat warnings as errors for the documentation build
|
||||||
|
export RUSTDOCFLAGS="-D warnings"
|
||||||
|
|
||||||
# Save the script start time
|
# Save the script start time
|
||||||
start_time=$(date +%s)
|
start_time=$(date +%s)
|
||||||
|
|
||||||
|
@ -65,11 +71,11 @@ start_time=$(date +%s)
|
||||||
rustup target add wasm32-unknown-unknown
|
rustup target add wasm32-unknown-unknown
|
||||||
rustup target add thumbv7m-none-eabi
|
rustup target add thumbv7m-none-eabi
|
||||||
|
|
||||||
# TODO decide if we should "cargo clean" here.
|
|
||||||
cargo build --workspace
|
cargo build --workspace
|
||||||
cargo test --workspace
|
cargo test --workspace
|
||||||
cargo fmt --check --all
|
cargo fmt --check --all
|
||||||
cargo clippy -- -D warnings
|
cargo clippy -- -D warnings
|
||||||
|
cargo doc --workspace
|
||||||
|
|
||||||
# no_std tests
|
# no_std tests
|
||||||
build_and_test_no_std "burn"
|
build_and_test_no_std "burn"
|
||||||
|
|
Loading…
Reference in New Issue