mirror of https://github.com/tracel-ai/burn.git
Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) (#2081)
* Add interpolate module * Update module.md * Add interpolate 1d and 2d modules * Consolidated InterpolateMode for 1d and 2d * Remove CoordinateTransformationMode * Add 1d tests for interpolate * Refactor and fixes of ONNX Resize OP * Fix clippy * Fix docs * Fix no_std
This commit is contained in:
parent
bc24bf3c14
commit
297173124f
|
@ -161,21 +161,23 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
|
||||
### General
|
||||
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| -------------- | --------------------------------------------- |
|
||||
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
|
||||
| `Dropout` | `nn.Dropout` |
|
||||
| `Embedding` | `nn.Embedding` |
|
||||
| `Gelu` | `nn.Gelu` |
|
||||
| `GroupNorm` | `nn.GroupNorm` |
|
||||
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
|
||||
| `LayerNorm` | `nn.LayerNorm` |
|
||||
| `LeakyRelu` | `nn.LeakyReLU` |
|
||||
| `Linear` | `nn.Linear` |
|
||||
| `Prelu` | `nn.PReLu` |
|
||||
| `Relu` | `nn.ReLU` |
|
||||
| `RmsNorm` | _No direct equivalent_ |
|
||||
| `SwiGlu` | _No direct equivalent_ |
|
||||
| Burn API | PyTorch Equivalent |
|
||||
| --------------- | --------------------------------------------- |
|
||||
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
|
||||
| `Dropout` | `nn.Dropout` |
|
||||
| `Embedding` | `nn.Embedding` |
|
||||
| `Gelu` | `nn.GELU` |
|
||||
| `GroupNorm` | `nn.GroupNorm` |
|
||||
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
|
||||
| `LayerNorm` | `nn.LayerNorm` |
|
||||
| `LeakyRelu` | `nn.LeakyReLU` |
|
||||
| `Linear` | `nn.Linear` |
|
||||
| `Prelu` | `nn.PReLU` |
|
||||
| `Relu` | `nn.ReLU` |
|
||||
| `RmsNorm` | _No direct equivalent_ |
|
||||
| `SwiGlu` | _No direct equivalent_ |
|
||||
| `Interpolate1d` | _No direct equivalent_ |
|
||||
| `Interpolate2d` | _No direct equivalent_ |
|
||||
|
||||
### Convolutions
|
||||
|
||||
|
|
|
@ -0,0 +1,248 @@
|
|||
use alloc::format;
|
||||
|
||||
use burn_tensor::module::interpolate;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::ops::InterpolateOptions;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
use super::InterpolateMode;
|
||||
|
||||
/// Configuration for the 1D interpolation module.
|
||||
///
|
||||
/// This struct defines the configuration options for the 1D interpolation operation.
|
||||
/// It allows specifying the output size, scale factor, and interpolation mode.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Interpolate1dConfig {
|
||||
/// Output size of the interpolated tensor.
|
||||
/// If specified, this takes precedence over `scale_factor`.
|
||||
#[config(default = "None")]
|
||||
pub output_size: Option<usize>,
|
||||
|
||||
/// Scale factor for resizing the input tensor.
|
||||
/// This is used when `output_size` is not specified.
|
||||
#[config(default = "None")]
|
||||
pub scale_factor: Option<f32>,
|
||||
|
||||
/// Interpolation mode to use for resizing.
|
||||
/// Determines how the output values are calculated.
|
||||
#[config(default = "InterpolateMode::Nearest")]
|
||||
pub mode: InterpolateMode,
|
||||
}
|
||||
|
||||
/// Interpolate module for resizing 1D tensors with shape [N, C, L].
|
||||
///
|
||||
/// This struct represents a 1D interpolation module that can resize tensors
|
||||
/// using various interpolation methods. It provides flexibility in specifying
|
||||
/// either an output size or a scale factor for resizing, along with options
|
||||
/// for the interpolation mode.
|
||||
///
|
||||
/// The module can be used to upsample or downsample 1D tensors, preserving the
|
||||
/// number of channels and batch size while adjusting the length dimension.
|
||||
///
|
||||
/// The module can be created using the [Interpolate1dConfig] struct and the
|
||||
/// `init` method, which returns an instance of the [Interpolate1d] struct.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Interpolate1d {
|
||||
/// Output size of the interpolated tensor
|
||||
pub output_size: Option<usize>,
|
||||
|
||||
/// Scale factor for resizing the input tensor
|
||||
pub scale_factor: Option<f32>,
|
||||
|
||||
/// Interpolation mode used for resizing
|
||||
pub mode: Ignored<InterpolateMode>,
|
||||
}
|
||||
|
||||
impl Interpolate1dConfig {
|
||||
/// Initialize the interpolation module
|
||||
pub fn init(self) -> Interpolate1d {
|
||||
Interpolate1d {
|
||||
output_size: self.output_size,
|
||||
scale_factor: self.scale_factor,
|
||||
mode: Ignored(self.mode),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Interpolate1d {
|
||||
/// Performs the forward pass of the 1D interpolation module
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - Input tensor with shape [N, C, L]
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Resized tensor with shape [N, C, L'], where L' is determined by
|
||||
/// the output_size or scale_factor specified in the module configuration
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// let input = Tensor::<Backend, 3>::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device);
|
||||
/// let interpolate = Interpolate1dConfig::new()
|
||||
/// .with_output_size(Some(128))
|
||||
/// .init();
|
||||
/// let output = interpolate.forward(input);
|
||||
/// assert_eq!(output.dims(), [1, 3, 128]);
|
||||
/// ```
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
|
||||
|
||||
// Use the interpolate operation to resize the temporal input tensor
|
||||
// by adding a new dimension for the interpolation axis
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let result = interpolate(
|
||||
input,
|
||||
[1, output_size],
|
||||
InterpolateOptions::new(self.mode.0.clone().into()),
|
||||
);
|
||||
|
||||
result.squeeze_dims(&[2])
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines the output length for a 1D interpolation.
///
/// # Arguments
///
/// * `input_dims` - Dimensions of the input tensor; only the last entry (the length) is read
/// * `output_size` - Explicit output length; when present it is returned unchanged and
///   takes precedence over `scale_factor`
/// * `scale_factor` - Factor the input length is multiplied by when no `output_size` is given
///
/// # Returns
///
/// The length of the interpolated tensor. When derived from `scale_factor`, the
/// fractional part of the scaled length is discarded.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided,
/// or if the scaled length exceeds `usize::MAX`
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        // An explicit size wins, even when a scale factor is configured as well.
        (Some(output_size), _) => output_size,
        (None, Some(scale_factor)) => {
            let [_, _, l] = input_dims;

            // Scale in f64 before checking the result against the usize range.
            let new_dim = (l as f64) * (scale_factor as f64);

            if new_dim > usize::MAX as f64 {
                panic!("Scale factor is too large");
            }

            new_dim as usize
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
|
||||
|
||||
impl ModuleDisplay for Interpolate1d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("mode", &self.mode)
|
||||
.add("output_size", &format!("{:?}", self.output_size))
|
||||
.add("scale_factor", &self.scale_factor)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {

    use burn_tensor::Distribution;

    use super::*;
    use crate::TestBackend;

    /// Covers an explicit size as well as up- and down-scaling factors.
    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4];

        // (output_size, scale_factor, expected length)
        let cases = [
            (Some(2), None, 2),
            (None, Some(2.0), 8),
            (None, Some(0.5), 2),
            (None, Some(1.5), 6),
        ];

        for (output_size, scale_factor, expected) in cases {
            let actual = calculate_output_size(input_dims, output_size, scale_factor);
            assert_eq!(actual, expected);
        }
    }

    /// Without a size or a scale factor there is nothing to compute.
    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        calculate_output_size([1, 1, 4], None, None);
    }

    /// Doubling a length of `usize::MAX - 1` cannot fit in `usize`.
    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        let huge_length = usize::MAX - 1;
        calculate_output_size([1, 1, huge_length], None, Some(2.0));
    }

    /// Checks the output shape produced by the module for each configuration style.
    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 3>::random(
            [2, 3, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // Explicit output size
        let by_size = Interpolate1dConfig::new()
            .with_output_size(Some(8))
            .init()
            .forward(input.clone());
        assert_eq!(by_size.dims(), [2, 3, 8]);

        // Scale factor
        let by_scale = Interpolate1dConfig::new()
            .with_scale_factor(Some(0.5))
            .init()
            .forward(input.clone());
        assert_eq!(by_scale.dims(), [2, 3, 2]);

        // Non-default interpolation mode
        let linear = Interpolate1dConfig::new()
            .with_output_size(Some(6))
            .with_mode(InterpolateMode::Linear)
            .init()
            .forward(input);
        assert_eq!(linear.dims(), [2, 3, 6]);
    }

    /// Checks the single-line rendering of the module.
    #[test]
    fn display() {
        let layer = Interpolate1dConfig::new()
            .with_output_size(Some(20))
            .init();
        let rendered = alloc::format!("{}", layer);

        assert_eq!(
            rendered,
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
            scale_factor: None}"
        );
    }
}
|
|
@ -0,0 +1,251 @@
|
|||
use alloc::format;
|
||||
|
||||
use burn_tensor::module::interpolate;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::ops::InterpolateOptions;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
use super::InterpolateMode;
|
||||
|
||||
/// Configuration for the 2D interpolation module.
|
||||
///
|
||||
/// This struct defines the configuration options for the 2D interpolation operation.
|
||||
/// It allows specifying the output size, scale factor, and interpolation mode.
|
||||
#[derive(Config, Debug)]
|
||||
pub struct Interpolate2dConfig {
|
||||
/// Output size of the interpolated tensor.
|
||||
/// If specified, this takes precedence over `scale_factor`.
|
||||
#[config(default = "None")]
|
||||
pub output_size: Option<[usize; 2]>,
|
||||
|
||||
/// Scale factor for resizing the input tensor.
|
||||
/// This is used when `output_size` is not specified.
|
||||
#[config(default = "None")]
|
||||
pub scale_factor: Option<[f32; 2]>,
|
||||
|
||||
/// Interpolation mode to use for resizing.
|
||||
/// Determines how the output values are calculated.
|
||||
#[config(default = "InterpolateMode::Nearest")]
|
||||
pub mode: InterpolateMode,
|
||||
}
|
||||
|
||||
/// Interpolate module for resizing tensors with shape [N, C, H, W].
|
||||
///
|
||||
/// This struct represents an interpolation module that can resize tensors
|
||||
/// using various interpolation methods. It provides flexibility in specifying
|
||||
/// either an output size or a scale factor for resizing, along with options
|
||||
/// for the interpolation mode.
|
||||
///
|
||||
/// The module can be used to upsample or downsample tensors, preserving the
|
||||
/// number of channels and batch size while adjusting the height and width
|
||||
/// dimensions.
|
||||
///
|
||||
/// The module can be created using the [Interpolate2dConfig] struct and the
|
||||
/// `init` method, which returns an instance of the [Interpolate2d] struct.
|
||||
#[derive(Module, Clone, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct Interpolate2d {
|
||||
/// Output size of the interpolated tensor
|
||||
pub output_size: Option<[usize; 2]>,
|
||||
|
||||
/// Scale factor for resizing the input tensor
|
||||
pub scale_factor: Option<[f32; 2]>,
|
||||
|
||||
/// Interpolation mode used for resizing
|
||||
pub mode: Ignored<InterpolateMode>,
|
||||
}
|
||||
|
||||
impl Interpolate2dConfig {
|
||||
/// Initialize the interpolation module
|
||||
pub fn init(self) -> Interpolate2d {
|
||||
Interpolate2d {
|
||||
output_size: self.output_size,
|
||||
scale_factor: self.scale_factor,
|
||||
mode: Ignored(self.mode),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Interpolate2d {
|
||||
/// Performs the forward pass of the interpolation module
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - Input tensor with shape [N, C, H, W]
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by
|
||||
/// the output_size or scale_factor specified in the module configuration
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// let input = Tensor::<Backend, 2>::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device);
|
||||
/// let interpolate = Interpolate2dConfig::new()
|
||||
/// .with_output_size(Some([128, 128]))
|
||||
/// .init();
|
||||
/// let output = interpolate.forward(input);
|
||||
/// assert_eq!(output.dims(), [1, 3, 128, 128]);
|
||||
/// ```
|
||||
pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
|
||||
interpolate(
|
||||
input,
|
||||
output_size,
|
||||
InterpolateOptions::new(self.mode.0.clone().into()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculates the output size for tensor interpolation.
///
/// # Arguments
///
/// * `input_dims` - The dimensions of the input tensor [N, C, H, W]; only H and W are read.
/// * `output_size` - Optional desired output size [H', W']; when present it is returned
///   unchanged and takes precedence over `scale_factor`.
/// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w],
///   used when no `output_size` is given.
///
/// # Returns
///
/// An array [H', W'] representing the calculated output size. When derived from
/// `scale_factor`, the fractional part of each scaled dimension is discarded.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided,
/// or if the scale factor results in dimensions exceeding usize::MAX.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        // An explicit size wins, even when scale factors are configured as well.
        (Some(output_size), _) => output_size,
        (None, Some([scale_h, scale_w])) => {
            let [_, _, h, w] = input_dims;

            // Scale in f64 before checking each result against the usize range.
            let new_dim_h = (h as f64) * (scale_h as f64);

            if new_dim_h > usize::MAX as f64 {
                panic!("Scale factor for height is too large");
            }

            let new_dim_w = (w as f64) * (scale_w as f64);

            if new_dim_w > usize::MAX as f64 {
                panic!("Scale factor for width is too large");
            }

            [new_dim_h as usize, new_dim_w as usize]
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
|
||||
|
||||
impl ModuleDisplay for Interpolate2d {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
content
|
||||
.add("mode", &self.mode)
|
||||
.add("output_size", &format!("{:?}", self.output_size))
|
||||
.add("scale_factor", &self.scale_factor)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
mod tests {
    use burn_tensor::Distribution;

    use crate::TestBackend;

    use super::*;

    /// Covers an explicit size as well as uniform and per-axis scale factors.
    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4, 4];

        // (output_size, scale_factor, expected [H', W'])
        let cases = [
            (Some([2, 2]), None, [2, 2]),
            (None, Some([2.0, 2.0]), [8, 8]),
            (None, Some([0.5, 0.5]), [2, 2]),
            (None, Some([2.0, 1.5]), [8, 6]),
        ];

        for (output_size, scale_factor, expected) in cases {
            let actual = calculate_output_size(input_dims, output_size, scale_factor);
            assert_eq!(actual, expected);
        }
    }

    /// Without a size or scale factors there is nothing to compute.
    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_missing_params() {
        let input_dims = [1, 1, 4, 4];
        calculate_output_size(input_dims, None, None);
    }

    /// Doubling a height of `usize::MAX - 1` cannot fit in `usize`.
    #[test]
    #[should_panic(expected = "Scale factor for height is too large")]
    fn test_infinite_height() {
        let input_dims = [1, 1, usize::MAX - 1, 4];
        calculate_output_size(input_dims, None, Some([2.0, 1.0]));
    }

    /// Doubling a width of `usize::MAX - 1` cannot fit in `usize`.
    #[test]
    #[should_panic(expected = "Scale factor for width is too large")]
    fn test_infinite_width() {
        let input_dims = [1, 1, 4, usize::MAX - 1];
        calculate_output_size(input_dims, None, Some([1.0, 2.0]));
    }

    /// Checks the output shape produced by the module for each configuration style.
    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 4>::random(
            [2, 3, 4, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // Explicit output size
        let by_size = Interpolate2dConfig::new()
            .with_output_size(Some([8, 8]))
            .init()
            .forward(input.clone());
        assert_eq!(by_size.dims(), [2, 3, 8, 8]);

        // Scale factor
        let by_scale = Interpolate2dConfig::new()
            .with_scale_factor(Some([0.5, 0.5]))
            .init()
            .forward(input.clone());
        assert_eq!(by_scale.dims(), [2, 3, 2, 2]);

        // Non-default interpolation mode
        let linear = Interpolate2dConfig::new()
            .with_output_size(Some([6, 6]))
            .with_mode(InterpolateMode::Linear)
            .init()
            .forward(input);
        assert_eq!(linear.dims(), [2, 3, 6, 6]);
    }

    /// Checks the single-line rendering of the module.
    #[test]
    fn display() {
        let layer = Interpolate2dConfig::new()
            .with_output_size(Some([20, 20]))
            .init();
        let rendered = alloc::format!("{}", layer);

        assert_eq!(
            rendered,
            "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
            scale_factor: None}"
        );
    }
}
|
|
@ -0,0 +1,46 @@
|
|||
mod interpolate1d;
|
||||
mod interpolate2d;
|
||||
|
||||
pub use interpolate1d::*;
|
||||
pub use interpolate2d::*;
|
||||
|
||||
use crate::tensor::ops::InterpolateMode as OpsInterpolateMode;
|
||||
|
||||
/// Algorithm used for downsampling and upsampling
|
||||
///
|
||||
/// This enum defines different interpolation modes for resampling data.
|
||||
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
|
||||
pub enum InterpolateMode {
|
||||
/// Nearest-neighbor interpolation
|
||||
///
|
||||
/// This mode selects the value of the nearest sample point for each output pixel.
|
||||
/// It is applicable for both temporal and spatial data.
|
||||
Nearest,
|
||||
|
||||
/// Linear interpolation
|
||||
///
|
||||
/// This mode calculates the output value using linear
|
||||
/// interpolation between nearby sample points.
|
||||
///
|
||||
/// It is applicable for both temporal and spatial data.
|
||||
Linear,
|
||||
|
||||
/// Cubic interpolation
|
||||
///
|
||||
/// This mode uses cubic interpolation to calculate the output value
|
||||
/// based on surrounding sample points.
|
||||
///
|
||||
/// It is applicable for both temporal and spatial data and generally
|
||||
/// provides smoother results than linear interpolation.
|
||||
Cubic,
|
||||
}
|
||||
|
||||
impl From<InterpolateMode> for OpsInterpolateMode {
|
||||
fn from(mode: InterpolateMode) -> Self {
|
||||
match mode {
|
||||
InterpolateMode::Nearest => OpsInterpolateMode::Nearest,
|
||||
InterpolateMode::Linear => OpsInterpolateMode::Bilinear,
|
||||
InterpolateMode::Cubic => OpsInterpolateMode::Bicubic,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,6 +16,9 @@ pub mod pool;
|
|||
/// Transformer module
|
||||
pub mod transformer;
|
||||
|
||||
/// Interpolate module
|
||||
pub mod interpolate;
|
||||
|
||||
mod dropout;
|
||||
mod embedding;
|
||||
mod gelu;
|
||||
|
|
|
@ -6,8 +6,8 @@ fn main() {
|
|||
|
||||
// Add onnx models.
|
||||
ModelGen::new()
|
||||
.input("tests/add/add_int.onnx")
|
||||
.input("tests/add/add.onnx")
|
||||
.input("tests/add/add_int.onnx")
|
||||
.input("tests/argmax/argmax.onnx")
|
||||
.input("tests/avg_pool1d/avg_pool1d.onnx")
|
||||
.input("tests/avg_pool2d/avg_pool2d.onnx")
|
||||
|
@ -16,9 +16,13 @@ fn main() {
|
|||
.input("tests/clip/clip_opset16.onnx")
|
||||
.input("tests/clip/clip_opset7.onnx")
|
||||
.input("tests/concat/concat.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
|
||||
.input("tests/conv1d/conv1d.onnx")
|
||||
.input("tests/conv2d/conv2d.onnx")
|
||||
.input("tests/conv3d/conv3d.onnx")
|
||||
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
|
||||
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
|
||||
.input("tests/cos/cos.onnx")
|
||||
.input("tests/div/div.onnx")
|
||||
.input("tests/dropout/dropout_opset16.onnx")
|
||||
|
@ -26,70 +30,71 @@ fn main() {
|
|||
.input("tests/equal/equal.onnx")
|
||||
.input("tests/erf/erf.onnx")
|
||||
.input("tests/exp/exp.onnx")
|
||||
.input("tests/expand/expand.onnx")
|
||||
.input("tests/flatten/flatten.onnx")
|
||||
.input("tests/gather/gather.onnx")
|
||||
.input("tests/gather_elements/gather_elements.onnx")
|
||||
.input("tests/gelu/gelu.onnx")
|
||||
.input("tests/global_avr_pool/global_avr_pool.onnx")
|
||||
.input("tests/greater/greater.onnx")
|
||||
.input("tests/greater_or_equal/greater_or_equal.onnx")
|
||||
.input("tests/layer_norm/layer_norm.onnx")
|
||||
.input("tests/leaky_relu/leaky_relu.onnx")
|
||||
.input("tests/less/less.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal.onnx")
|
||||
.input("tests/linear/linear.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
.input("tests/log/log.onnx")
|
||||
.input("tests/log_softmax/log_softmax.onnx")
|
||||
.input("tests/mask_where/mask_where.onnx")
|
||||
.input("tests/matmul/matmul.onnx")
|
||||
.input("tests/min/min.onnx")
|
||||
.input("tests/max/max.onnx")
|
||||
.input("tests/maxpool1d/maxpool1d.onnx")
|
||||
.input("tests/maxpool2d/maxpool2d.onnx")
|
||||
.input("tests/min/min.onnx")
|
||||
.input("tests/mul/mul.onnx")
|
||||
.input("tests/neg/neg.onnx")
|
||||
.input("tests/not/not.onnx")
|
||||
.input("tests/pad/pad.onnx")
|
||||
.input("tests/expand/expand.onnx")
|
||||
.input("tests/greater/greater.onnx")
|
||||
.input("tests/greater_or_equal/greater_or_equal.onnx")
|
||||
.input("tests/less/less.onnx")
|
||||
.input("tests/less_or_equal/less_or_equal.onnx")
|
||||
.input("tests/recip/recip.onnx")
|
||||
.input("tests/relu/relu.onnx")
|
||||
.input("tests/leaky_relu/leaky_relu.onnx")
|
||||
.input("tests/pow/pow.onnx")
|
||||
.input("tests/pow/pow_int.onnx")
|
||||
.input("tests/prelu/prelu.onnx")
|
||||
.input("tests/random_normal/random_normal.onnx")
|
||||
.input("tests/random_uniform/random_uniform.onnx")
|
||||
.input("tests/range/range.onnx")
|
||||
.input("tests/recip/recip.onnx")
|
||||
.input("tests/reduce_max/reduce_max.onnx")
|
||||
.input("tests/reduce_min/reduce_min.onnx")
|
||||
.input("tests/reduce_mean/reduce_mean.onnx")
|
||||
.input("tests/reduce_min/reduce_min.onnx")
|
||||
.input("tests/reduce_prod/reduce_prod.onnx")
|
||||
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
|
||||
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
|
||||
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
|
||||
.input("tests/relu/relu.onnx")
|
||||
.input("tests/reshape/reshape.onnx")
|
||||
.input("tests/resize/resize.onnx")
|
||||
.input("tests/resize/resize_with_sizes.onnx")
|
||||
.input("tests/resize/resize_1d_linear_scale.onnx")
|
||||
.input("tests/resize/resize_1d_nearest_scale.onnx")
|
||||
.input("tests/resize/resize_2d_bicubic_scale.onnx")
|
||||
.input("tests/resize/resize_2d_bilinear_scale.onnx")
|
||||
.input("tests/resize/resize_2d_nearest_scale.onnx")
|
||||
.input("tests/shape/shape.onnx")
|
||||
.input("tests/sigmoid/sigmoid.onnx")
|
||||
.input("tests/sign/sign.onnx")
|
||||
.input("tests/sin/sin.onnx")
|
||||
.input("tests/slice/slice.onnx")
|
||||
.input("tests/softmax/softmax.onnx")
|
||||
.input("tests/sqrt/sqrt.onnx")
|
||||
.input("tests/sub/sub_int.onnx")
|
||||
.input("tests/squeeze/squeeze_multiple.onnx")
|
||||
.input("tests/squeeze/squeeze_opset13.onnx")
|
||||
.input("tests/squeeze/squeeze_opset16.onnx")
|
||||
.input("tests/sub/sub.onnx")
|
||||
.input("tests/tanh/tanh.onnx")
|
||||
.input("tests/transpose/transpose.onnx")
|
||||
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
|
||||
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
|
||||
.input("tests/pow/pow.onnx")
|
||||
.input("tests/pow/pow_int.onnx")
|
||||
.input("tests/slice/slice.onnx")
|
||||
.input("tests/sub/sub_int.onnx")
|
||||
.input("tests/sum/sum.onnx")
|
||||
.input("tests/sum/sum_int.onnx")
|
||||
.input("tests/tanh/tanh.onnx")
|
||||
.input("tests/transpose/transpose.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
|
||||
.input("tests/mask_where/mask_where.onnx")
|
||||
.input("tests/squeeze/squeeze_opset16.onnx")
|
||||
.input("tests/squeeze/squeeze_opset13.onnx")
|
||||
.input("tests/squeeze/squeeze_multiple.onnx")
|
||||
.input("tests/random_uniform/random_uniform.onnx")
|
||||
.input("tests/random_normal/random_normal.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape.onnx")
|
||||
.input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
|
||||
.input("tests/range/range.onnx")
|
||||
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
|
||||
|
|
|
@ -15,19 +15,23 @@ macro_rules! include_models {
|
|||
|
||||
// ATTENTION: Modify this macro to include all models in the `model` directory.
|
||||
include_models!(
|
||||
add_int,
|
||||
add,
|
||||
add_int,
|
||||
argmax,
|
||||
avg_pool2d,
|
||||
avg_pool1d,
|
||||
avg_pool2d,
|
||||
batch_norm,
|
||||
cast,
|
||||
clip_opset16,
|
||||
clip_opset7,
|
||||
concat,
|
||||
constant_of_shape,
|
||||
constant_of_shape_full_like,
|
||||
conv1d,
|
||||
conv2d,
|
||||
conv3d,
|
||||
conv_transpose2d,
|
||||
conv_transpose3d,
|
||||
cos,
|
||||
div,
|
||||
dropout_opset16,
|
||||
|
@ -41,37 +45,46 @@ include_models!(
|
|||
gather_elements,
|
||||
gelu,
|
||||
global_avr_pool,
|
||||
greater,
|
||||
greater_or_equal,
|
||||
layer_norm,
|
||||
leaky_relu,
|
||||
less,
|
||||
less_or_equal,
|
||||
linear,
|
||||
log_softmax,
|
||||
log,
|
||||
log_softmax,
|
||||
mask_where,
|
||||
matmul,
|
||||
min,
|
||||
max,
|
||||
maxpool1d,
|
||||
maxpool2d,
|
||||
min,
|
||||
mul,
|
||||
neg,
|
||||
not,
|
||||
pad,
|
||||
greater,
|
||||
greater_or_equal,
|
||||
less,
|
||||
less_or_equal,
|
||||
pow,
|
||||
pow_int,
|
||||
prelu,
|
||||
random_normal,
|
||||
random_uniform,
|
||||
range,
|
||||
recip,
|
||||
reduce_max,
|
||||
reduce_min,
|
||||
reduce_mean,
|
||||
reduce_min,
|
||||
reduce_prod,
|
||||
reduce_sum_opset13,
|
||||
reduce_sum_opset11,
|
||||
reduce_sum_opset13,
|
||||
relu,
|
||||
reshape,
|
||||
resize,
|
||||
resize_with_sizes,
|
||||
resize_1d_linear_scale,
|
||||
resize_1d_nearest_scale,
|
||||
resize_2d_bicubic_scale,
|
||||
resize_2d_bilinear_scale,
|
||||
resize_2d_nearest_scale,
|
||||
shape,
|
||||
sigmoid,
|
||||
sign,
|
||||
|
@ -79,26 +92,18 @@ include_models!(
|
|||
slice,
|
||||
softmax,
|
||||
sqrt,
|
||||
sub_int,
|
||||
squeeze_multiple,
|
||||
squeeze_opset13,
|
||||
squeeze_opset16,
|
||||
sub,
|
||||
sub_int,
|
||||
sum,
|
||||
sum_int,
|
||||
tanh,
|
||||
transpose,
|
||||
conv_transpose2d,
|
||||
conv_transpose3d,
|
||||
pow,
|
||||
pow_int,
|
||||
unsqueeze,
|
||||
unsqueeze_opset16,
|
||||
unsqueeze_opset11,
|
||||
squeeze_opset16,
|
||||
squeeze_opset13,
|
||||
squeeze_multiple,
|
||||
random_uniform,
|
||||
random_normal,
|
||||
constant_of_shape,
|
||||
constant_of_shape_full_like
|
||||
unsqueeze_opset16
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -865,10 +870,10 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn resize() {
|
||||
fn resize_with_sizes() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
let model: resize::Model<Backend> = resize::Model::new(&device);
|
||||
let model: resize_with_sizes::Model<Backend> = resize_with_sizes::Model::new(&device);
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats(
|
||||
|
@ -880,14 +885,153 @@ mod tests {
|
|||
]]],
|
||||
&device,
|
||||
);
|
||||
let size = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 3], &device);
|
||||
|
||||
let output = model.forward(input, size);
|
||||
// The sizes are [1, 1, 2, 3]
|
||||
let output = model.forward(input);
|
||||
let expected = TensorData::from([[[[0.0f32, 1.5, 3.0], [12.0, 13.5, 15.0]]]]);
|
||||
|
||||
output.to_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
fn resize_with_scales_1d_linear() {
    // The exported file contains no weights, so the model is created without loading any
    let device = Default::default();
    let model = resize_1d_linear_scale::Model::<Backend>::new(&device);

    // Single-batch, single-channel signal of length 6
    let signal = Tensor::<Backend, 3>::from_floats(
        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
        &device,
    );

    // The model resizes with a scale of 1.5
    let resized = model.forward(signal);

    // Compare the sum of all elements with the reference value
    let actual_sum = resized.sum().into_scalar();
    let reference_sum = -4.568_224; // from pytorch

    assert!(reference_sum.approx_eq(actual_sum, (1.0e-4, 2)));
}
|
||||
|
||||
#[test]
fn resize_with_scales_2d_bilinear() {
    // The exported file contains no weights, so the model is created without loading any
    let device = Default::default();
    let model = resize_2d_bilinear_scale::Model::<Backend>::new(&device);

    // Single-batch, single-channel 6x6 image
    let image = Tensor::<Backend, 4>::from_floats(
        [[[
            [-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920],
            [-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081],
            [0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959],
            [0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412],
            [-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022],
            [-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048],
        ]]],
        &device,
    );

    // The model resizes with scales of 1.5, 1.5
    let resized = model.forward(image);

    // Compare the sum of all elements with the reference value
    let actual_sum = resized.sum().into_scalar();
    let reference_sum = -3.401_126_6; // from pytorch

    assert!(reference_sum.approx_eq(actual_sum, (1.0e-4, 2)));
}
|
||||
|
||||
#[test]
|
||||
fn resize_with_scales_2d_nearest() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
let model: resize_2d_nearest_scale::Model<Backend> =
|
||||
resize_2d_nearest_scale::Model::<Backend>::new(&device);
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats(
|
||||
[[[
|
||||
[-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920],
|
||||
[-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081],
|
||||
[0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959],
|
||||
[0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412],
|
||||
[-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022],
|
||||
[-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048],
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
// The scales are 1.5, 1.5
|
||||
let output = model.forward(input);
|
||||
|
||||
assert_eq!(output.dims(), [1, 1, 9, 9]);
|
||||
|
||||
let output_sum = output.sum().into_scalar();
|
||||
let expected_sum = -0.812_227_7; // from pytorch
|
||||
|
||||
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resize_with_scales_1d_nearest() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
let model: resize_1d_nearest_scale::Model<Backend> =
|
||||
resize_1d_nearest_scale::Model::<Backend>::new(&device);
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 3>::from_floats(
|
||||
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
// The scale is 1.5 (single spatial dimension)
|
||||
let output = model.forward(input);
|
||||
|
||||
assert_eq!(output.dims(), [1, 1, 9]);
|
||||
|
||||
let output_sum = output.sum().into_scalar();
|
||||
let expected_sum = -4.568_224; // from pytorch
|
||||
|
||||
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resize_with_scales_2d_bicubic() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let device = Default::default();
|
||||
let model: resize_2d_bicubic_scale::Model<Backend> =
|
||||
resize_2d_bicubic_scale::Model::<Backend>::new(&device);
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats(
|
||||
[[[
|
||||
[-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920],
|
||||
[-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081],
|
||||
[0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959],
|
||||
[0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412],
|
||||
[-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022],
|
||||
[-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048],
|
||||
]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
// The scales are 1.5, 1.5
|
||||
let output = model.forward(input);
|
||||
|
||||
assert_eq!(output.dims(), [1, 1, 9, 9]);
|
||||
|
||||
let output_sum = output.sum().into_scalar();
|
||||
|
||||
let expected_sum = -3.515_921; // from pytorch
|
||||
|
||||
assert!(expected_sum.approx_eq(output_sum, (1.0e-3, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn shape() {
|
||||
let device = Default::default();
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.onnx
|
||||
|
||||
class InterpolateModel(nn.Module):
|
||||
def __init__(self, scale_factor=None, size=None, mode='nearest', align_corners=None):
|
||||
super(InterpolateModel, self).__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.size = size
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return nn.functional.interpolate(x, scale_factor=self.scale_factor, size=self.size,
|
||||
mode=self.mode, align_corners=self.align_corners)
|
||||
|
||||
def export_interpolate_onnx(filename, batch_size=1, channels=1, height=6, width=6,
|
||||
scale_factor=None, size=None, mode='nearest', dim=2, align_corners=None):
|
||||
model = InterpolateModel(scale_factor, size, mode, align_corners)
|
||||
model.eval()
|
||||
|
||||
# Add seed for reproducibility
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Create a dummy input
|
||||
if dim == 1:
|
||||
dummy_input = torch.randn(batch_size, channels, width)
|
||||
elif dim == 2:
|
||||
dummy_input = torch.randn(batch_size, channels, height, width)
|
||||
else:
|
||||
raise ValueError("Unsupported dimension. Use 1 for temporal or 2 for spatial.")
|
||||
|
||||
# Export the model
|
||||
torch.onnx.export(model, dummy_input, filename,
|
||||
input_names=['input'], output_names=['output'],
|
||||
dynamic_axes={'input': {0: 'batch_size'},
|
||||
'output': {0: 'batch_size'}},
|
||||
opset_version=17)
|
||||
|
||||
output = model(dummy_input)
|
||||
print(f"Input shape: {dummy_input.shape}")
|
||||
print(f"Output shape: {output.shape}")
|
||||
|
||||
# Print sum data
|
||||
print(f"Input sum: {dummy_input.sum()}")
|
||||
print(f"Output sum: {output.sum()}")
|
||||
|
||||
print(f"Input: {dummy_input}")
|
||||
print(f"Output: {output}")
|
||||
|
||||
print(f"Model exported to {filename}")
|
||||
|
||||
print()
|
||||
|
||||
# Usage examples:
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
||||
# 1D (temporal) examples
|
||||
export_interpolate_onnx("resize_1d_nearest_scale.onnx", scale_factor=1.5, mode='nearest', dim=1)
|
||||
export_interpolate_onnx("resize_1d_linear_scale.onnx", scale_factor=1.5, mode='linear', dim=1, align_corners=True)
|
||||
|
||||
# Cubic interpolation is not supported for 1D tensors
|
||||
# export_interpolate_onnx("resize_1d_cubic_scale.onnx", scale_factor=1.5, mode='cubic', dim=1)
|
||||
|
||||
# 2D (spatial) examples
|
||||
export_interpolate_onnx("resize_2d_nearest_scale.onnx", scale_factor=1.5, mode='nearest', dim=2)
|
||||
export_interpolate_onnx("resize_2d_bilinear_scale.onnx", scale_factor=1.5, mode='bilinear', dim=2, align_corners=True)
|
||||
export_interpolate_onnx("resize_2d_bicubic_scale.onnx", scale_factor=1.5, mode='bicubic', dim=2, align_corners=True)
|
Binary file not shown.
16
crates/burn-import/onnx-tests/tests/resize/resize.py → crates/burn-import/onnx-tests/tests/resize/resize_with_sizes.py
Normal file → Executable file
16
crates/burn-import/onnx-tests/tests/resize/resize.py → crates/burn-import/onnx-tests/tests/resize/resize_with_sizes.py
Normal file → Executable file
|
@ -4,10 +4,19 @@
|
|||
|
||||
import onnx
|
||||
from onnx import helper, TensorProto
|
||||
import numpy as np
|
||||
|
||||
def main() -> None:
|
||||
input_tensor = helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [1, 1, 4, 4])
|
||||
sizes_tensor = helper.make_tensor_value_info("sizes", TensorProto.INT64, [4])
|
||||
|
||||
# Create sizes as a constant tensor
|
||||
sizes = np.array([1, 1, 2, 3], dtype=np.int64)
|
||||
sizes_tensor = helper.make_tensor(
|
||||
name="sizes",
|
||||
data_type=TensorProto.INT64,
|
||||
dims=sizes.shape,
|
||||
vals=sizes.flatten().tolist(),
|
||||
)
|
||||
|
||||
resize_node = helper.make_node(
|
||||
"Resize",
|
||||
|
@ -20,15 +29,16 @@ def main() -> None:
|
|||
graph_def = helper.make_graph(
|
||||
nodes=[resize_node],
|
||||
name="ResizeGraph",
|
||||
inputs=[input_tensor, sizes_tensor],
|
||||
inputs=[input_tensor],
|
||||
outputs=[
|
||||
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1, 2, 2])
|
||||
],
|
||||
initializer=[sizes_tensor],
|
||||
)
|
||||
|
||||
model_def = helper.make_model(graph_def, producer_name="resize")
|
||||
|
||||
onnx.save(model_def, "resize.onnx")
|
||||
onnx.save(model_def, "resize_with_sizes.onnx")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
|
@ -64,6 +64,13 @@ impl ToTokens for f64 {
|
|||
}
|
||||
}
|
||||
|
||||
/// Prettier output for `f32`
|
||||
impl ToTokens for f32 {
|
||||
fn to_tokens(&self) -> TokenStream {
|
||||
convert_primitive(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Padding configuration
|
||||
impl ToTokens for PaddingConfig1d {
|
||||
fn to_tokens(&self) -> TokenStream {
|
||||
|
|
|
@ -1,29 +1,17 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{OtherType, Scope, TensorType, Type};
|
||||
use burn::module::Module;
|
||||
use crate::burn::{OtherType, Scope, TensorType, ToTokens, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Module, Debug, Clone)]
|
||||
pub enum ResizeMode {
|
||||
Nearest,
|
||||
Linear,
|
||||
Cubic,
|
||||
}
|
||||
|
||||
#[derive(new, Module, Debug, Clone)]
|
||||
pub struct ResizeOptions {
|
||||
pub mode: ResizeMode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResizeNode {
|
||||
pub field: OtherType,
|
||||
pub input: TensorType,
|
||||
pub output: TensorType,
|
||||
pub output_size: TensorType,
|
||||
pub config: ResizeOptions,
|
||||
mode: String,
|
||||
scales: Vec<f32>,
|
||||
sizes: Vec<usize>,
|
||||
}
|
||||
|
||||
impl ResizeNode {
|
||||
|
@ -31,20 +19,29 @@ impl ResizeNode {
|
|||
name: S,
|
||||
input: TensorType,
|
||||
output: TensorType,
|
||||
output_size: TensorType,
|
||||
config: ResizeOptions,
|
||||
mode: String,
|
||||
scales: Vec<f32>,
|
||||
sizes: Vec<usize>,
|
||||
) -> Self {
|
||||
let ty = if input.dim == 3 {
|
||||
quote! {
|
||||
Interpolate1d
|
||||
}
|
||||
} else if input.dim == 4 {
|
||||
quote! {
|
||||
Interpolate2d
|
||||
}
|
||||
} else {
|
||||
panic!("Unsupported input dimension for resize node");
|
||||
};
|
||||
|
||||
Self {
|
||||
field: OtherType::new(
|
||||
name,
|
||||
quote! {
|
||||
burn::module::Ignored<InterpolateOptions>
|
||||
},
|
||||
),
|
||||
field: OtherType::new(name, ty),
|
||||
input,
|
||||
output,
|
||||
output_size,
|
||||
config,
|
||||
mode,
|
||||
scales,
|
||||
sizes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -55,10 +52,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ResizeNode {
|
|||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![
|
||||
Type::Tensor(self.input.clone()),
|
||||
Type::Tensor(self.output_size.clone()),
|
||||
]
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
|
@ -68,59 +62,96 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ResizeNode {
|
|||
fn field_init(&self) -> Option<TokenStream> {
|
||||
let name = &self.field.name;
|
||||
|
||||
let mode = match self.config.mode {
|
||||
ResizeMode::Linear => quote! { InterpolateMode::Bilinear },
|
||||
ResizeMode::Nearest => quote! { InterpolateMode::Nearest },
|
||||
ResizeMode::Cubic => quote! { InterpolateMode::Bicubic },
|
||||
let mode = match self.mode.as_str() {
|
||||
"nearest" => quote! { InterpolateMode::Nearest },
|
||||
"linear" => quote! { InterpolateMode::Linear },
|
||||
"cubic" => quote! { InterpolateMode::Cubic },
|
||||
_ => panic!("Unsupported mode for resize node"),
|
||||
};
|
||||
|
||||
let tokens = quote! {
|
||||
let #name = InterpolateOptions {
|
||||
mode: #mode,
|
||||
let tokens = if self.input.dim == 3 {
|
||||
let size = if let Some(size) = self.sizes.first() {
|
||||
let size = size.to_tokens();
|
||||
quote! { Some(#size) }
|
||||
} else {
|
||||
quote! { None }
|
||||
};
|
||||
let #name = burn::module::Ignored(#name);
|
||||
|
||||
let scale_factor = if let Some(scale) = self.scales.first() {
|
||||
let scale = scale.to_tokens();
|
||||
quote! { Some(#scale) }
|
||||
} else {
|
||||
quote! { None }
|
||||
};
|
||||
|
||||
quote! {
|
||||
let #name = Interpolate1dConfig::new()
|
||||
.with_output_size(#size)
|
||||
.with_scale_factor(#scale_factor)
|
||||
.with_mode(#mode)
|
||||
.init();
|
||||
}
|
||||
} else if self.input.dim == 4 {
|
||||
let size = if self.sizes.len() == 2 {
|
||||
let h = self.sizes[0].to_tokens();
|
||||
let w = self.sizes[1].to_tokens();
|
||||
quote! { Some([#h, #w]) }
|
||||
} else {
|
||||
quote! { None }
|
||||
};
|
||||
|
||||
let scale_factor = if self.scales.len() == 2 {
|
||||
let h = self.scales[0].to_tokens();
|
||||
let w = self.scales[1].to_tokens();
|
||||
quote! { Some([#h, #w]) }
|
||||
} else {
|
||||
quote! { None }
|
||||
};
|
||||
|
||||
quote! {
|
||||
let #name = Interpolate2dConfig::new()
|
||||
.with_output_size(#size)
|
||||
.with_scale_factor(#scale_factor)
|
||||
.with_mode(#mode)
|
||||
.init();
|
||||
}
|
||||
} else {
|
||||
panic!("Unsupported input dimension for resize node");
|
||||
};
|
||||
|
||||
Some(tokens)
|
||||
}
|
||||
|
||||
fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
|
||||
imports.register("burn::nn::interpolate::InterpolateMode");
|
||||
if self.input.dim == 3 {
|
||||
imports.register("burn::nn::interpolate::Interpolate1dConfig");
|
||||
imports.register("burn::nn::interpolate::Interpolate1d");
|
||||
} else if self.input.dim == 4 {
|
||||
imports.register("burn::nn::interpolate::Interpolate2dConfig");
|
||||
imports.register("burn::nn::interpolate::Interpolate2d");
|
||||
} else {
|
||||
panic!("Unsupported input dimension for resize node");
|
||||
}
|
||||
}
|
||||
|
||||
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
S::serialize_none(serializer)
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output_size = scope.tensor_use_owned(&self.output_size, node_position);
|
||||
let output = &self.output.name;
|
||||
|
||||
let field = &self.field.name;
|
||||
|
||||
quote! {
|
||||
let output_size_data = #output_size.to_data();
|
||||
let mut output_size = [0usize; 2];
|
||||
|
||||
for (i, &x) in output_size_data.as_slice::<B::IntElem>().unwrap().iter().rev().take(2).rev().enumerate() {
|
||||
output_size[i] = x.elem::<i64>() as usize;
|
||||
}
|
||||
|
||||
let #output = interpolate(
|
||||
#input,
|
||||
output_size,
|
||||
self.#field.0.clone(),
|
||||
);
|
||||
let #output = self.#field.forward(#input);
|
||||
}
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::Resize(self)
|
||||
}
|
||||
|
||||
fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
|
||||
imports.register("burn::tensor::ElementConversion");
|
||||
imports.register("burn::tensor::module::interpolate");
|
||||
imports.register("burn::tensor::ops::InterpolateMode");
|
||||
imports.register("burn::tensor::ops::InterpolateOptions");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -135,47 +166,42 @@ mod tests {
|
|||
};
|
||||
|
||||
#[test]
|
||||
fn test_codegen_nodes() {
|
||||
fn test_codegen_nodes_2d() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(ResizeNode::new(
|
||||
"resize",
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
TensorType::new_int("output_size", 1),
|
||||
ResizeOptions::new(ResizeMode::Linear),
|
||||
"nearest".to_string(),
|
||||
vec![0.5, 0.5],
|
||||
vec![],
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "output_size".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::tensor::module::interpolate;
|
||||
use burn::tensor::ops::InterpolateMode;
|
||||
use burn::tensor::ops::InterpolateOptions;
|
||||
use burn::tensor::ElementConversion;
|
||||
use burn::tensor::Int;
|
||||
use burn::nn::interpolate::Interpolate2d;
|
||||
use burn::nn::interpolate::Interpolate2dConfig;
|
||||
use burn::nn::interpolate::InterpolateMode;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
resize: burn::module::Ignored<InterpolateOptions>,
|
||||
resize: Interpolate2d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
impl<B: Backend> Model<B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
let resize = InterpolateOptions {
|
||||
mode: InterpolateMode::Bilinear,
|
||||
};
|
||||
let resize = burn::module::Ignored(resize);
|
||||
let resize = Interpolate2dConfig::new()
|
||||
.with_output_size(None)
|
||||
.with_scale_factor(Some([0.5, 0.5]))
|
||||
.with_mode(InterpolateMode::Nearest)
|
||||
.init();
|
||||
Self {
|
||||
resize,
|
||||
phantom: core::marker::PhantomData,
|
||||
|
@ -183,20 +209,62 @@ mod tests {
|
|||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 4>,
|
||||
output_size: Tensor<B, 1, Int>
|
||||
) -> Tensor<B, 4> {
|
||||
let output_size_data = output_size.to_data();
|
||||
let mut output_size = [0usize; 2];
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = self.resize.forward(tensor1);
|
||||
tensor2
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (i, &x) in output_size_data.as_slice::<B::IntElem>().unwrap().iter().rev().take(2).rev().enumerate() {
|
||||
output_size[i] = x.elem::<i64>() as usize;
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_nodes_1d() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(ResizeNode::new(
|
||||
"resize",
|
||||
TensorType::new_float("tensor1", 3),
|
||||
TensorType::new_float("tensor2", 3),
|
||||
"cubic".to_string(),
|
||||
vec![],
|
||||
vec![20],
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::nn::interpolate::Interpolate1d;
|
||||
use burn::nn::interpolate::Interpolate1dConfig;
|
||||
use burn::nn::interpolate::InterpolateMode;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
resize: Interpolate1d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
impl<B: Backend> Model<B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
let resize = Interpolate1dConfig::new()
|
||||
.with_output_size(Some(20))
|
||||
.with_scale_factor(None)
|
||||
.with_mode(InterpolateMode::Cubic)
|
||||
.init();
|
||||
Self {
|
||||
resize,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
|
||||
let tensor2 = interpolate(tensor1, output_size, self.resize.0.clone());
|
||||
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(&self, tensor1: Tensor<B, 3>) -> Tensor<B, 3> {
|
||||
let tensor2 = self.resize.forward(tensor1);
|
||||
tensor2
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ use burn::nn::{
|
|||
PaddingConfig2d, PaddingConfig3d,
|
||||
};
|
||||
|
||||
use crate::burn::node::{pad::PadConfig, resize::ResizeMode};
|
||||
use crate::burn::node::pad::PadConfig;
|
||||
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
|
||||
|
||||
/// Create a Conv1dConfig from the attributes of the node
|
||||
|
@ -976,26 +976,132 @@ pub fn reshape_config(node: &Node) -> Vec<i64> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn resize_config(node: &Node) -> ResizeMode {
|
||||
pub fn resize_config(node: &Node) -> (String, Vec<f32>, Vec<usize>) {
|
||||
let mut mode: String = "".to_string();
|
||||
|
||||
let mut scales: Vec<f32>;
|
||||
let mut sizes: Vec<usize>;
|
||||
|
||||
let input = if let ArgType::Tensor(tensor) = &node
|
||||
.inputs
|
||||
.first()
|
||||
.expect("Resize: Input tensor must be present")
|
||||
.ty
|
||||
{
|
||||
tensor
|
||||
} else {
|
||||
panic!("Resize: input must be a tensor")
|
||||
};
|
||||
|
||||
// Note: we are ignoring some attributes because results are approximately the same
|
||||
// and we are not supporting all the attributes of the Resize operator.
|
||||
// However, some attributes are important to be checked and we are checking
|
||||
// against the default values of the attributes.
|
||||
// TODO revisit this when we have more Resize operators in the model
|
||||
for (key, value) in node.attrs.iter() {
|
||||
match key.as_str() {
|
||||
"coordinate_transformation_mode" => {}
|
||||
"cubic_coeff_a" => {}
|
||||
"mode" => mode = value.clone().into_string(),
|
||||
"nearest_mode" => {}
|
||||
"antialias" => assert_eq!(
|
||||
value.clone().into_i32(),
|
||||
0,
|
||||
"Resize: antialias other than 0 is not supported"
|
||||
),
|
||||
"axes" => panic!("Resize: custom axes attribute is not supported"),
|
||||
"coordinate_transformation_mode" => {
|
||||
log::warn!("Resize: coordinate_transformation_mode is ignored")
|
||||
}
|
||||
|
||||
"cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"),
|
||||
"exclude_outside" => assert_eq!(
|
||||
value.clone().into_i32(),
|
||||
0,
|
||||
"Resize: exclude_outside other than 0 is not supported"
|
||||
),
|
||||
"extrapolation_value" => assert_eq!(
|
||||
value.clone().into_f32(),
|
||||
0.0,
|
||||
"Resize: extrapolation_value other than 0.0 is not supported"
|
||||
),
|
||||
"keep_aspect_ratio_policy" => {
|
||||
assert_eq!(
|
||||
value.clone().into_string().to_lowercase(),
|
||||
"stretch",
|
||||
"Resize: keep_aspect_ratio_policy other than 'stretch' is not supported"
|
||||
)
|
||||
}
|
||||
"mode" => mode = value.clone().into_string().to_lowercase(),
|
||||
"nearest_mode" => log::warn!("Resize: nearest_mode is ignored"),
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let mode = match mode.as_str() {
|
||||
"nearest" => ResizeMode::Nearest,
|
||||
"linear" => ResizeMode::Linear,
|
||||
"cubic" => ResizeMode::Cubic,
|
||||
_ => panic!("Resize: invalid mode string, must be 'nearest', 'linear', or 'cubic'"),
|
||||
};
|
||||
let roi: Vec<f32> = node
|
||||
.inputs
|
||||
.get(1)
|
||||
.map(|input| {
|
||||
if let Some(data) = &input.value {
|
||||
data.clone().into_f32s()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
mode
|
||||
scales = node
|
||||
.inputs
|
||||
.get(2)
|
||||
.map(|input| {
|
||||
if let Some(data) = &input.value {
|
||||
data.clone().into_f32s()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
sizes = node
|
||||
.inputs
|
||||
.get(3)
|
||||
.map(|input| {
|
||||
if let Some(data) = &input.value {
|
||||
data.clone()
|
||||
.into_i64s()
|
||||
.iter()
|
||||
.map(|&x| x as usize)
|
||||
.collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if mode.is_empty() {
|
||||
panic!("Resize: mode attribute is required")
|
||||
}
|
||||
|
||||
if !roi.is_empty() {
|
||||
panic!("Resize: roi input is not supported")
|
||||
}
|
||||
|
||||
if scales.is_empty() && sizes.is_empty() {
|
||||
panic!("Resize: either scales or sizes input is required")
|
||||
}
|
||||
|
||||
if !scales.is_empty() {
|
||||
assert!(scales.len() == input.dim);
|
||||
// ignore the first two items from scales
|
||||
// because they are the batch and channel dimensions
|
||||
scales = scales.iter().skip(2).cloned().collect();
|
||||
}
|
||||
|
||||
if !sizes.is_empty() {
|
||||
assert!(sizes.len() == input.dim);
|
||||
// ignore the first two items from sizes
|
||||
// because they are the batch and channel dimensions
|
||||
sizes = sizes.iter().skip(2).cloned().collect();
|
||||
}
|
||||
|
||||
(mode, scales, sizes)
|
||||
}
|
||||
|
||||
//Note this function should only execute if the second input is a constant
|
||||
|
|
|
@ -46,7 +46,7 @@ use crate::{
|
|||
random_uniform::RandomUniformNode,
|
||||
range::RangeNode,
|
||||
reshape::ReshapeNode,
|
||||
resize::{ResizeNode, ResizeOptions},
|
||||
resize::ResizeNode,
|
||||
slice::SliceNode,
|
||||
squeeze::SqueezeNode,
|
||||
sum::SumNode,
|
||||
|
@ -646,13 +646,12 @@ impl ParsedOnnxGraph {
|
|||
let name = &node.name;
|
||||
|
||||
let input = TensorType::from(&node.inputs[0]);
|
||||
let output_size = TensorType::from(&node.inputs[3]);
|
||||
|
||||
let output = TensorType::from(node.outputs.first().unwrap());
|
||||
|
||||
let mode = resize_config(&node);
|
||||
let (mode, scales, sizes) = resize_config(&node);
|
||||
|
||||
ResizeNode::new(name, input, output, output_size, ResizeOptions { mode })
|
||||
ResizeNode::new(name, input, output, mode, scales, sizes)
|
||||
}
|
||||
|
||||
fn min_conversion(node: Node) -> BinaryNode {
|
||||
|
|
|
@ -128,7 +128,7 @@ pub struct UnfoldOptions {
|
|||
}
|
||||
|
||||
/// Algorithm used for upsampling.
|
||||
#[derive(new, Debug, Clone)]
|
||||
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
|
||||
pub enum InterpolateMode {
|
||||
/// Nearest-neighbor interpolation.
|
||||
/// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>
|
||||
|
|
|
@ -85,6 +85,47 @@ mod tests {
|
|||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
|
||||
fn test_1d_bicubic() {
|
||||
// Use the default device; this test calls `interpolate` directly (no exported model is involved)
|
||||
let device = Default::default();
|
||||
|
||||
// Build the 1D input tensor (it is unsqueezed to 4D below before calling `interpolate`)
|
||||
let input = TestTensor::<3>::from_floats(
|
||||
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let output = interpolate(
|
||||
input,
|
||||
[1, 9],
|
||||
InterpolateOptions::new(InterpolateMode::Bicubic),
|
||||
);
|
||||
|
||||
assert_eq!(output.dims(), [1, 1, 1, 9]);
|
||||
|
||||
// assert output data does not contain NaN
|
||||
assert!(
|
||||
!output
|
||||
.clone()
|
||||
.to_data()
|
||||
.as_slice::<f32>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|&x| x.is_nan()),
|
||||
"interpolate output contains NaN"
|
||||
);
|
||||
|
||||
TestTensor::<4>::from([[[[
|
||||
1.541, 0.5747652, -1.010614, -2.197787, -0.8269969, 0.59609234, -0.5803058, -1.3792794,
|
||||
-1.3986,
|
||||
]]]])
|
||||
.to_data()
|
||||
.assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
|
|
|
@ -85,6 +85,55 @@ mod tests {
|
|||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
|
||||
fn test_1d_bilinear() {
|
||||
// Use the default device; this test calls `interpolate` directly (no exported model is involved)
|
||||
let device = Default::default();
|
||||
|
||||
// Build the 1D input tensor (it is unsqueezed to 4D below before calling `interpolate`)
|
||||
let input = TestTensor::<3>::from_floats(
|
||||
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let output = interpolate(
|
||||
input,
|
||||
[1, 9],
|
||||
InterpolateOptions::new(InterpolateMode::Bilinear),
|
||||
);
|
||||
|
||||
assert_eq!(output.dims(), [1, 1, 1, 9]);
|
||||
|
||||
// assert output data does not contain NaN
|
||||
assert!(
|
||||
!output
|
||||
.clone()
|
||||
.to_data()
|
||||
.as_slice::<f32>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|&x| x.is_nan()),
|
||||
"interpolate output contains NaN"
|
||||
);
|
||||
|
||||
TestTensor::<4>::from([[[[
|
||||
1.541f32,
|
||||
0.39450002,
|
||||
-0.76475,
|
||||
-1.943125,
|
||||
-0.80520004,
|
||||
0.36178753,
|
||||
-0.671275,
|
||||
-1.2022874,
|
||||
-1.3986,
|
||||
]]]])
|
||||
.to_data()
|
||||
.assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
|
|
|
@ -59,6 +59,45 @@ mod tests {
|
|||
]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_1d_nearest() {
|
||||
// Use the default device; this test calls `interpolate` directly (no exported model is involved)
|
||||
let device = Default::default();
|
||||
|
||||
// Build the 1D input tensor (it is unsqueezed to 4D below before calling `interpolate`)
|
||||
let input = TestTensor::<3>::from_floats(
|
||||
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
|
||||
&device,
|
||||
);
|
||||
|
||||
let input = input.unsqueeze_dim(2);
|
||||
|
||||
let output = interpolate(
|
||||
input,
|
||||
[1, 9],
|
||||
InterpolateOptions::new(InterpolateMode::Nearest),
|
||||
);
|
||||
assert_eq!(output.dims(), [1, 1, 1, 9]);
|
||||
|
||||
// assert output data does not contain NaN
|
||||
assert!(
|
||||
!output
|
||||
.clone()
|
||||
.to_data()
|
||||
.as_slice::<f32>()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.any(|&x| x.is_nan()),
|
||||
"interpolate output contains NaN"
|
||||
);
|
||||
|
||||
TestTensor::<4>::from([[[[
|
||||
1.541, 1.541, -0.2934, -2.1788, -2.1788, 0.5684, -1.0845, -1.0845, -1.3986,
|
||||
]]]])
|
||||
.to_data()
|
||||
.assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
|
||||
struct InterpolateTestCase {
|
||||
batch_size: usize,
|
||||
channels: usize,
|
||||
|
|
|
@ -62,7 +62,7 @@ pub fn dim_inference(node: &mut Node) {
|
|||
NodeType::ReduceSum => reduce_sum_update_outputs(node),
|
||||
NodeType::Relu => same_as_input(node),
|
||||
NodeType::Reshape => reshape_update_outputs(node),
|
||||
NodeType::Resize => resize_update_outputs(node),
|
||||
NodeType::Resize => same_as_input(node),
|
||||
NodeType::Shape => shape_update_outputs(node),
|
||||
NodeType::Sigmoid => same_as_input(node),
|
||||
NodeType::Sign => same_as_input(node),
|
||||
|
@ -318,33 +318,6 @@ fn reshape_update_outputs(node: &mut Node) {
|
|||
}
|
||||
}
|
||||
|
||||
fn resize_update_outputs(node: &mut Node) {
|
||||
let input = match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Resize: invalid input type"),
|
||||
};
|
||||
|
||||
let output = match &node.outputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Resize: invalid output type"),
|
||||
};
|
||||
|
||||
let output_size = match &node.inputs[3].ty {
|
||||
ArgType::Tensor(output_size) => output_size.clone(),
|
||||
_ => panic!("Resize: invalid output_size type"),
|
||||
};
|
||||
|
||||
if output_size.dim != 1 {
|
||||
panic!("Resize: output_size must be 1D");
|
||||
}
|
||||
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
dim: input.dim,
|
||||
shape: None, // shape is calculated at runtime
|
||||
..output
|
||||
});
|
||||
}
|
||||
|
||||
fn greater_update_outputs(node: &mut Node) {
|
||||
match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => {
|
||||
|
@ -838,7 +811,7 @@ fn gather_update_outputs(node: &mut Node) {
|
|||
|
||||
let input_tensor = match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor,
|
||||
_ => panic!("Only tensor input is valid"),
|
||||
ty => panic!("Only tensor input is valid but received: {:?}", ty),
|
||||
};
|
||||
|
||||
let indices_tensor = match &node.inputs[1].ty {
|
||||
|
|
|
@ -523,44 +523,54 @@ impl Data {
|
|||
_ => self,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_f16(self) -> f16 {
|
||||
if let Data::Float16(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Float16, got {:?}", self);
|
||||
match self {
|
||||
Data::Float16(elem) => elem,
|
||||
Data::Float32(elem) => f16::from_f32(elem),
|
||||
Data::Float64(elem) => f16::from_f64(elem),
|
||||
_ => panic!("Cannot convert {:?} to f16", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_f32(self) -> f32 {
|
||||
if let Data::Float32(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Float32, got {:?}", self);
|
||||
match self {
|
||||
Data::Float16(elem) => elem.to_f32(),
|
||||
Data::Float32(elem) => elem,
|
||||
Data::Float64(elem) => elem as f32,
|
||||
Data::Int32(elem) => elem as f32,
|
||||
Data::Int64(elem) => elem as f32,
|
||||
_ => panic!("Cannot convert {:?} to f32", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_f64(self) -> f64 {
|
||||
if let Data::Float64(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Float64, got {:?}", self);
|
||||
match self {
|
||||
Data::Float16(elem) => elem.to_f64(),
|
||||
Data::Float32(elem) => elem as f64,
|
||||
Data::Float64(elem) => elem,
|
||||
Data::Int32(elem) => elem as f64,
|
||||
Data::Int64(elem) => elem as f64,
|
||||
_ => panic!("Cannot convert {:?} to f64", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_i32(self) -> i32 {
|
||||
if let Data::Int32(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Int32, got {:?}", self);
|
||||
match self {
|
||||
Data::Int32(elem) => elem,
|
||||
Data::Int64(elem) => elem as i32,
|
||||
Data::Float32(elem) => elem as i32,
|
||||
Data::Float64(elem) => elem as i32,
|
||||
_ => panic!("Cannot convert {:?} to i32", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_i64(self) -> i64 {
|
||||
if let Data::Int64(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Int64, got {:?}", self);
|
||||
match self {
|
||||
Data::Int32(elem) => elem as i64,
|
||||
Data::Int64(elem) => elem,
|
||||
Data::Float32(elem) => elem as i64,
|
||||
Data::Float64(elem) => elem as i64,
|
||||
_ => panic!("Cannot convert {:?} to i64", self),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -581,42 +591,53 @@ impl Data {
|
|||
}
|
||||
|
||||
pub fn into_f16s(self) -> Vec<f16> {
|
||||
if let Data::Float16s(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Float16s, got {:?}", self);
|
||||
match self {
|
||||
Data::Float16s(elem) => elem,
|
||||
Data::Float32s(elem) => elem.into_iter().map(f16::from_f32).collect(),
|
||||
Data::Float64s(elem) => elem.into_iter().map(f16::from_f64).collect(),
|
||||
_ => panic!("Cannot convert {:?} to Vec<f16>", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_f32s(self) -> Vec<f32> {
|
||||
if let Data::Float32s(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Float32s, got {:?}", self);
|
||||
match self {
|
||||
Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f32()).collect(),
|
||||
Data::Float32s(elem) => elem,
|
||||
Data::Float64s(elem) => elem.into_iter().map(|x| x as f32).collect(),
|
||||
Data::Int32s(elem) => elem.into_iter().map(|x| x as f32).collect(),
|
||||
Data::Int64s(elem) => elem.into_iter().map(|x| x as f32).collect(),
|
||||
_ => panic!("Cannot convert {:?} to Vec<f32>", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_f64s(self) -> Vec<f64> {
|
||||
if let Data::Float64s(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Float64s, got {:?}", self);
|
||||
match self {
|
||||
Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f64()).collect(),
|
||||
Data::Float32s(elem) => elem.into_iter().map(|x| x as f64).collect(),
|
||||
Data::Float64s(elem) => elem,
|
||||
Data::Int32s(elem) => elem.into_iter().map(|x| x as f64).collect(),
|
||||
Data::Int64s(elem) => elem.into_iter().map(|x| x as f64).collect(),
|
||||
_ => panic!("Cannot convert {:?} to Vec<f64>", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_i32s(self) -> Vec<i32> {
|
||||
if let Data::Int32s(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Int32s, got {:?}", self);
|
||||
match self {
|
||||
Data::Int32s(elem) => elem,
|
||||
Data::Int64s(elem) => elem.into_iter().map(|x| x as i32).collect(),
|
||||
Data::Float32s(elem) => elem.into_iter().map(|x| x as i32).collect(),
|
||||
Data::Float64s(elem) => elem.into_iter().map(|x| x as i32).collect(),
|
||||
_ => panic!("Cannot convert {:?} to Vec<i32>", self),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_i64s(self) -> Vec<i64> {
|
||||
if let Data::Int64s(elem) = self {
|
||||
elem
|
||||
} else {
|
||||
panic!("Expected Int64s, got {:?}", self);
|
||||
match self {
|
||||
Data::Int32s(elem) => elem.into_iter().map(|x| x as i64).collect(),
|
||||
Data::Int64s(elem) => elem,
|
||||
Data::Float32s(elem) => elem.into_iter().map(|x| x as i64).collect(),
|
||||
Data::Float64s(elem) => elem.into_iter().map(|x| x as i64).collect(),
|
||||
_ => panic!("Cannot convert {:?} to Vec<i64>", self),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue