Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) (#2081)

* Add interpolate module

* Update module.md

* Add interpolate 1d and 2d modules

* Consolidated InterpolateMode for 1d and 2d

* Remove CoordinateTransformationMode

* Add 1d tests for interpolate

* Refactor and fixes of ONNX Resize OP

* Fix clippy

* Fix docs

* Fix no_std
This commit is contained in:
Dilshod Tadjibaev 2024-07-31 12:08:26 -05:00 committed by GitHub
parent bc24bf3c14
commit 297173124f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1340 additions and 257 deletions

View File

@ -161,21 +161,23 @@ Burn comes with built-in modules that you can use to build your own modules.
### General
| Burn API | PyTorch Equivalent |
| -------------- | --------------------------------------------- |
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
| `Dropout` | `nn.Dropout` |
| `Embedding` | `nn.Embedding` |
| `Gelu` | `nn.Gelu` |
| `GroupNorm` | `nn.GroupNorm` |
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `LayerNorm` | `nn.LayerNorm` |
| `LeakyRelu` | `nn.LeakyReLU` |
| `Linear` | `nn.Linear` |
| `Prelu` | `nn.PReLu` |
| `Relu` | `nn.ReLU` |
| `RmsNorm` | _No direct equivalent_ |
| `SwiGlu` | _No direct equivalent_ |
| Burn API | PyTorch Equivalent |
| --------------- | --------------------------------------------- |
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
| `Dropout` | `nn.Dropout` |
| `Embedding` | `nn.Embedding` |
| `Gelu` | `nn.GELU` |
| `GroupNorm` | `nn.GroupNorm` |
| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. |
| `LayerNorm` | `nn.LayerNorm` |
| `LeakyRelu` | `nn.LeakyReLU` |
| `Linear` | `nn.Linear` |
| `Prelu` | `nn.PReLU` |
| `Relu` | `nn.ReLU` |
| `RmsNorm` | _No direct equivalent_ |
| `SwiGlu` | _No direct equivalent_ |
| `Interpolate1d` | `nn.Upsample` |
| `Interpolate2d` | `nn.Upsample` |
### Convolutions

View File

@ -0,0 +1,248 @@
use alloc::format;
use burn_tensor::module::interpolate;
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::ops::InterpolateOptions;
use crate::tensor::Tensor;
use super::InterpolateMode;
/// Configuration for the 1D interpolation module.
///
/// This struct defines the configuration options for the 1D interpolation operation.
/// It allows specifying the output size, scale factor, and interpolation mode.
///
/// Both sizing options default to `None`; at least one of them has to be set before the
/// module's forward pass runs, otherwise the size calculation panics.
#[derive(Config, Debug)]
pub struct Interpolate1dConfig {
/// Output size of the interpolated tensor.
/// If specified, this takes precedence over `scale_factor`.
/// Defaults to `None`.
#[config(default = "None")]
pub output_size: Option<usize>,
/// Scale factor for resizing the input tensor.
/// This is used when `output_size` is not specified.
/// The resulting length is the input length multiplied by this factor, truncated to an integer.
/// Defaults to `None`.
#[config(default = "None")]
pub scale_factor: Option<f32>,
/// Interpolation mode to use for resizing.
/// Determines how the output values are calculated.
/// Defaults to [InterpolateMode::Nearest].
#[config(default = "InterpolateMode::Nearest")]
pub mode: InterpolateMode,
}
/// Interpolate module for resizing 1D tensors with shape [N, C, L].
///
/// This struct represents a 1D interpolation module that can resize tensors
/// using various interpolation methods. It provides flexibility in specifying
/// either an output size or a scale factor for resizing, along with options
/// for the interpolation mode.
///
/// The module can be used to upsample or downsample 1D tensors, preserving the
/// number of channels and batch size while adjusting the length dimension.
///
/// The module can be created using the [Interpolate1dConfig] struct and the
/// `init` method, which returns an instance of the [Interpolate1d] struct.
///
/// The fields are copied verbatim from the configuration; the module holds no tensors.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate1d {
/// Output size (length L') of the interpolated tensor
pub output_size: Option<usize>,
/// Scale factor applied to the input length when no output size is given
pub scale_factor: Option<f32>,
/// Interpolation mode used for resizing
// NOTE(review): wrapped in `Ignored`, presumably so the `Module` derive skips this
// non-tensor field - confirm against the `Ignored` documentation.
pub mode: Ignored<InterpolateMode>,
}
impl Interpolate1dConfig {
    /// Builds an [Interpolate1d] module from this configuration.
    ///
    /// The configuration is consumed: `output_size` and `scale_factor` are moved into the
    /// module unchanged and `mode` is wrapped in [Ignored].
    pub fn init(self) -> Interpolate1d {
        let Self {
            output_size,
            scale_factor,
            mode,
        } = self;

        Interpolate1d {
            output_size,
            scale_factor,
            mode: Ignored(mode),
        }
    }
}
impl Interpolate1d {
    /// Runs the 1D interpolation on `input`.
    ///
    /// # Arguments
    ///
    /// * `input` - Input tensor with shape [N, C, L]
    ///
    /// # Returns
    ///
    /// A tensor with shape [N, C, L'], where L' comes from the module's `output_size`
    /// or is derived from its `scale_factor`.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let input = Tensor::<Backend, 3>::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device);
    /// let interpolate = Interpolate1dConfig::new()
    ///     .with_output_size(Some(128))
    ///     .init();
    /// let output = interpolate.forward(input);
    /// assert_eq!(output.dims(), [1, 3, 128]);
    /// ```
    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let target_len = calculate_output_size(input.dims(), self.output_size, self.scale_factor);

        // The underlying op works on 4D tensors, so insert a singleton axis at position 2
        // ([N, C, L] -> [N, C, 1, L]) and keep that axis at size 1 in the requested output.
        let expanded: Tensor<B, 4> = input.unsqueeze_dim(2);
        let options = InterpolateOptions::new(self.mode.0.clone().into());
        let resized = interpolate(expanded, [1, target_len], options);

        // Drop the helper axis again to return to [N, C, L'].
        resized.squeeze_dims(&[2])
    }
}
/// Calculate the output length for 1D interpolation.
///
/// An explicit `output_size` always wins: when it is set it is returned unchanged, even if a
/// `scale_factor` is set as well (this matches the precedence documented on the configuration).
/// Otherwise the input length is multiplied by `scale_factor` and truncated to an integer.
///
/// # Arguments
///
/// * `input_dims` - Input dimensions of the tensor, `[N, C, L]`
/// * `output_size` - Output size for the interpolated tensor
/// * `scale_factor` - Scale factor for resizing the tensor
///
/// # Returns
///
/// Output size for the interpolated tensor
///
/// # Panics
///
/// Panics if neither output_size nor scale_factor is provided
/// or if the scale factor is too large
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        // Use the provided size; it takes precedence over any scale factor
        (Some(output_size), _) => output_size,
        (None, Some(scale_factor)) => {
            // Calculate output size based on scale factor.
            // The product is computed in f64 so that the overflow check below is meaningful.
            let [_, _, l] = input_dims;
            let new_dim = (l as f64) * (scale_factor as f64);
            if new_dim > usize::MAX as f64 {
                panic!("Scale factor is too large");
            }
            // Truncates the fractional part
            new_dim as usize
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
impl ModuleDisplay for Interpolate1d {
    /// Display settings: attributes are not followed by a new line.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Lists `mode`, `output_size` and `scale_factor`, in that order.
    fn custom_content(&self, content: Content) -> Option<Content> {
        // `output_size` is rendered through its `Debug` form, e.g. `Some(20)`.
        let output_size = format!("{:?}", self.output_size);

        let content = content.add("mode", &self.mode);
        let content = content.add("output_size", &output_size);
        let content = content.add("scale_factor", &self.scale_factor);
        content.optional()
    }
}
#[cfg(test)]
mod tests {
use burn_tensor::Distribution;
use super::*;
use crate::TestBackend;
/// Covers an explicit output size as well as integer and fractional scale factors.
#[test]
fn test_calculate_output_size() {
let input_dims = [1, 1, 4];
// An explicit size is returned unchanged
let output_size = calculate_output_size(input_dims, Some(2), None);
assert_eq!(output_size, 2);
// 4 * 2.0 = 8
let output_size = calculate_output_size(input_dims, None, Some(2.0));
assert_eq!(output_size, 8);
// 4 * 0.5 = 2
let output_size = calculate_output_size(input_dims, None, Some(0.5));
assert_eq!(output_size, 2);
// 4 * 1.5 = 6
let output_size = calculate_output_size(input_dims, None, Some(1.5));
assert_eq!(output_size, 6);
}
/// Neither an output size nor a scale factor is given.
#[test]
#[should_panic(expected = "Either output_size or scale_factor must be provided")]
fn test_panic() {
let input_dims = [1, 1, 4];
calculate_output_size(input_dims, None, None);
}
/// The scaled length does not fit into `usize`.
#[test]
#[should_panic(expected = "Scale factor is too large")]
fn test_large_scale_factor() {
let input_dims = [1, 1, usize::MAX - 1];
calculate_output_size(input_dims, None, Some(2.0));
}
/// Only output shapes are checked here, not the interpolated values.
#[test]
fn test_module() {
let input = Tensor::<TestBackend, 3>::random(
[2, 3, 4],
Distribution::Uniform(0.0, 1.0),
&Default::default(),
);
// Test with output_size
let config = Interpolate1dConfig::new().with_output_size(Some(8));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 8]);
// Test with scale_factor
let config = Interpolate1dConfig::new().with_scale_factor(Some(0.5));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 2]);
// Test with different interpolation mode
let config = Interpolate1dConfig::new()
.with_output_size(Some(6))
.with_mode(InterpolateMode::Linear);
let interpolate = config.init();
let output = interpolate.forward(input);
assert_eq!(output.dims(), [2, 3, 6]);
}
/// Pins the exact string produced by the custom display implementation.
#[test]
fn display() {
let config = Interpolate1dConfig::new().with_output_size(Some(20));
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"Interpolate1d {mode: Nearest, output_size: Some(20), \
scale_factor: None}"
);
}
}

View File

@ -0,0 +1,251 @@
use alloc::format;
use burn_tensor::module::interpolate;
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::ops::InterpolateOptions;
use crate::tensor::Tensor;
use super::InterpolateMode;
/// Configuration for the 2D interpolation module.
///
/// This struct defines the configuration options for the 2D interpolation operation.
/// It allows specifying the output size, scale factor, and interpolation mode.
///
/// Both sizing options default to `None`; at least one of them has to be set before the
/// module's forward pass runs, otherwise the size calculation panics.
#[derive(Config, Debug)]
pub struct Interpolate2dConfig {
/// Output size of the interpolated tensor, as `[height, width]`.
/// If specified, this takes precedence over `scale_factor`.
/// Defaults to `None`.
#[config(default = "None")]
pub output_size: Option<[usize; 2]>,
/// Scale factor for resizing the input tensor, as `[scale_height, scale_width]`.
/// This is used when `output_size` is not specified.
/// Each output dimension is the input dimension multiplied by its factor, truncated to an integer.
/// Defaults to `None`.
#[config(default = "None")]
pub scale_factor: Option<[f32; 2]>,
/// Interpolation mode to use for resizing.
/// Determines how the output values are calculated.
/// Defaults to [InterpolateMode::Nearest].
#[config(default = "InterpolateMode::Nearest")]
pub mode: InterpolateMode,
}
/// Interpolate module for resizing tensors with shape [N, C, H, W].
///
/// This struct represents an interpolation module that can resize tensors
/// using various interpolation methods. It provides flexibility in specifying
/// either an output size or a scale factor for resizing, along with options
/// for the interpolation mode.
///
/// The module can be used to upsample or downsample tensors, preserving the
/// number of channels and batch size while adjusting the height and width
/// dimensions.
///
/// The module can be created using the [Interpolate2dConfig] struct and the
/// `init` method, which returns an instance of the [Interpolate2d] struct.
///
/// The fields are copied verbatim from the configuration; the module holds no tensors.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate2d {
/// Output size of the interpolated tensor, as `[height, width]`
pub output_size: Option<[usize; 2]>,
/// Scale factor for resizing the input tensor, as `[scale_height, scale_width]`
pub scale_factor: Option<[f32; 2]>,
/// Interpolation mode used for resizing
// NOTE(review): wrapped in `Ignored`, presumably so the `Module` derive skips this
// non-tensor field - confirm against the `Ignored` documentation.
pub mode: Ignored<InterpolateMode>,
}
impl Interpolate2dConfig {
    /// Builds an [Interpolate2d] module from this configuration.
    ///
    /// The configuration is consumed: `output_size` and `scale_factor` are moved into the
    /// module unchanged and `mode` is wrapped in [Ignored].
    pub fn init(self) -> Interpolate2d {
        let Self {
            output_size,
            scale_factor,
            mode,
        } = self;

        Interpolate2d {
            output_size,
            scale_factor,
            mode: Ignored(mode),
        }
    }
}
impl Interpolate2d {
    /// Performs the forward pass of the 2D interpolation module
    ///
    /// # Arguments
    ///
    /// * `input` - Input tensor with shape [N, C, H, W]
    ///
    /// # Returns
    ///
    /// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by
    /// the output_size or scale_factor specified in the module configuration
    ///
    /// # Panics
    ///
    /// Panics if neither `output_size` nor `scale_factor` is set on the module, or if a
    /// scaled dimension exceeds `usize::MAX`.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let input = Tensor::<Backend, 4>::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device);
    /// let interpolate = Interpolate2dConfig::new()
    ///     .with_output_size(Some([128, 128]))
    ///     .init();
    /// let output = interpolate.forward(input);
    /// assert_eq!(output.dims(), [1, 3, 128, 128]);
    /// ```
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        // Resolve the target [H', W'] first; this is where a misconfigured module panics.
        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
        interpolate(
            input,
            output_size,
            InterpolateOptions::new(self.mode.0.clone().into()),
        )
    }
}
/// Calculates the output size for tensor interpolation.
///
/// An explicit `output_size` always wins: when it is set it is returned unchanged, even if a
/// `scale_factor` is set as well (this matches the precedence documented on the configuration).
/// Otherwise each spatial dimension is multiplied by its scale factor and truncated to an
/// integer.
///
/// # Arguments
///
/// * `input_dims` - The dimensions of the input tensor [N, C, H, W].
/// * `output_size` - Optional desired output size [H', W'].
/// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w].
///
/// # Returns
///
/// An array [H', W'] representing the calculated output size.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided,
/// or if the scale factor results in dimensions exceeding usize::MAX.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        // Use the provided size; it takes precedence over any scale factor
        (Some(output_size), _) => output_size,
        (None, Some(scale_factor)) => {
            // Calculate output size based on scale factor.
            // Products are computed in f64 so that the overflow checks below are meaningful.
            let [_, _, h, w] = input_dims;

            let new_dim_h = (h as f64) * (scale_factor[0] as f64);
            if new_dim_h > usize::MAX as f64 {
                panic!("Scale factor for height is too large");
            }

            let new_dim_w = (w as f64) * (scale_factor[1] as f64);
            if new_dim_w > usize::MAX as f64 {
                panic!("Scale factor for width is too large");
            }

            // Truncates the fractional parts
            [new_dim_h as usize, new_dim_w as usize]
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
impl ModuleDisplay for Interpolate2d {
    /// Display settings: attributes are not followed by a new line.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Lists `mode`, `output_size` and `scale_factor`, in that order.
    fn custom_content(&self, content: Content) -> Option<Content> {
        // `output_size` is rendered through its `Debug` form, e.g. `Some([20, 20])`.
        let output_size = format!("{:?}", self.output_size);

        let content = content.add("mode", &self.mode);
        let content = content.add("output_size", &output_size);
        let content = content.add("scale_factor", &self.scale_factor);
        content.optional()
    }
}
#[cfg(test)]
mod tests {
use burn_tensor::Distribution;
use crate::TestBackend;
use super::*;
/// Covers an explicit output size as well as integer, fractional and mixed scale factors.
#[test]
fn test_calculate_output_size() {
let input_dims = [1, 1, 4, 4];
// An explicit size is returned unchanged
let output_size = calculate_output_size(input_dims, Some([2, 2]), None);
assert_eq!(output_size, [2, 2]);
// 4 * 2.0 = 8 on both axes
let output_size = calculate_output_size(input_dims, None, Some([2.0, 2.0]));
assert_eq!(output_size, [8, 8]);
// 4 * 0.5 = 2 on both axes
let output_size = calculate_output_size([1, 1, 4, 4], None, Some([0.5, 0.5]));
assert_eq!(output_size, [2, 2]);
// Height and width are scaled independently
let output_size = calculate_output_size([1, 1, 4, 4], None, Some([2.0, 1.5]));
assert_eq!(output_size, [8, 6]);
}
/// Neither an output size nor a scale factor is given.
#[test]
#[should_panic(expected = "Either output_size or scale_factor must be provided")]
fn test_missing_params() {
calculate_output_size([1, 1, 4, 4], None, None);
}
/// The scaled height does not fit into `usize`.
#[test]
#[should_panic(expected = "Scale factor for height is too large")]
fn test_infinite_height() {
calculate_output_size([1, 1, usize::MAX - 1, 4], None, Some([2.0, 1.0]));
}
/// The scaled width does not fit into `usize`.
#[test]
#[should_panic(expected = "Scale factor for width is too large")]
fn test_infinite_width() {
calculate_output_size([1, 1, 4, usize::MAX - 1], None, Some([1.0, 2.0]));
}
/// Only output shapes are checked here, not the interpolated values.
#[test]
fn test_module() {
let input = Tensor::<TestBackend, 4>::random(
[2, 3, 4, 4],
Distribution::Uniform(0.0, 1.0),
&Default::default(),
);
// Test with output_size
let config = Interpolate2dConfig::new().with_output_size(Some([8, 8]));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 8, 8]);
// Test with scale_factor
let config = Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5]));
let interpolate = config.init();
let output = interpolate.forward(input.clone());
assert_eq!(output.dims(), [2, 3, 2, 2]);
// Test with different interpolation mode
let config = Interpolate2dConfig::new()
.with_output_size(Some([6, 6]))
.with_mode(InterpolateMode::Linear);
let interpolate = config.init();
let output = interpolate.forward(input);
assert_eq!(output.dims(), [2, 3, 6, 6]);
}
/// Pins the exact string produced by the custom display implementation.
#[test]
fn display() {
let config = Interpolate2dConfig::new().with_output_size(Some([20, 20]));
let layer = config.init();
assert_eq!(
alloc::format!("{}", layer),
"Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
scale_factor: None}"
);
}
}

View File

@ -0,0 +1,46 @@
mod interpolate1d;
mod interpolate2d;
pub use interpolate1d::*;
pub use interpolate2d::*;
use crate::tensor::ops::InterpolateMode as OpsInterpolateMode;
/// Algorithm used for downsampling and upsampling
///
/// This enum defines different interpolation modes for resampling data.
/// It is shared by the 1D and 2D interpolation modules and converts into the
/// tensor-op level interpolation mode through the `From` implementation below.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
/// Nearest-neighbor interpolation
///
/// This mode selects the value of the nearest sample point for each output pixel.
/// It is applicable for both temporal and spatial data.
Nearest,
/// Linear interpolation
///
/// This mode calculates the output value using linear
/// interpolation between nearby sample points.
///
/// It is applicable for both temporal and spatial data.
/// Converts to the tensor-op `Bilinear` mode.
Linear,
/// Cubic interpolation
///
/// This mode uses cubic interpolation to calculate the output value
/// based on surrounding sample points.
///
/// It is applicable for both temporal and spatial data and generally
/// provides smoother results than linear interpolation.
/// Converts to the tensor-op `Bicubic` mode.
Cubic,
}
impl From<InterpolateMode> for OpsInterpolateMode {
    /// Maps the module-level mode onto the tensor-op mode.
    ///
    /// `Nearest` maps to `Nearest`, `Linear` to `Bilinear` and `Cubic` to `Bicubic`.
    fn from(mode: InterpolateMode) -> Self {
        match mode {
            InterpolateMode::Nearest => Self::Nearest,
            InterpolateMode::Linear => Self::Bilinear,
            InterpolateMode::Cubic => Self::Bicubic,
        }
    }
}

View File

@ -16,6 +16,9 @@ pub mod pool;
/// Transformer module
pub mod transformer;
/// Interpolate module
pub mod interpolate;
mod dropout;
mod embedding;
mod gelu;

View File

@ -6,8 +6,8 @@ fn main() {
// Add onnx models.
ModelGen::new()
.input("tests/add/add_int.onnx")
.input("tests/add/add.onnx")
.input("tests/add/add_int.onnx")
.input("tests/argmax/argmax.onnx")
.input("tests/avg_pool1d/avg_pool1d.onnx")
.input("tests/avg_pool2d/avg_pool2d.onnx")
@ -16,9 +16,13 @@ fn main() {
.input("tests/clip/clip_opset16.onnx")
.input("tests/clip/clip_opset7.onnx")
.input("tests/concat/concat.onnx")
.input("tests/constant_of_shape/constant_of_shape.onnx")
.input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
.input("tests/conv1d/conv1d.onnx")
.input("tests/conv2d/conv2d.onnx")
.input("tests/conv3d/conv3d.onnx")
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
.input("tests/cos/cos.onnx")
.input("tests/div/div.onnx")
.input("tests/dropout/dropout_opset16.onnx")
@ -26,70 +30,71 @@ fn main() {
.input("tests/equal/equal.onnx")
.input("tests/erf/erf.onnx")
.input("tests/exp/exp.onnx")
.input("tests/expand/expand.onnx")
.input("tests/flatten/flatten.onnx")
.input("tests/gather/gather.onnx")
.input("tests/gather_elements/gather_elements.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
.input("tests/layer_norm/layer_norm.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/less/less.onnx")
.input("tests/less_or_equal/less_or_equal.onnx")
.input("tests/linear/linear.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/log/log.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/mask_where/mask_where.onnx")
.input("tests/matmul/matmul.onnx")
.input("tests/min/min.onnx")
.input("tests/max/max.onnx")
.input("tests/maxpool1d/maxpool1d.onnx")
.input("tests/maxpool2d/maxpool2d.onnx")
.input("tests/min/min.onnx")
.input("tests/mul/mul.onnx")
.input("tests/neg/neg.onnx")
.input("tests/not/not.onnx")
.input("tests/pad/pad.onnx")
.input("tests/expand/expand.onnx")
.input("tests/greater/greater.onnx")
.input("tests/greater_or_equal/greater_or_equal.onnx")
.input("tests/less/less.onnx")
.input("tests/less_or_equal/less_or_equal.onnx")
.input("tests/recip/recip.onnx")
.input("tests/relu/relu.onnx")
.input("tests/leaky_relu/leaky_relu.onnx")
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/prelu/prelu.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/range/range.onnx")
.input("tests/recip/recip.onnx")
.input("tests/reduce_max/reduce_max.onnx")
.input("tests/reduce_min/reduce_min.onnx")
.input("tests/reduce_mean/reduce_mean.onnx")
.input("tests/reduce_min/reduce_min.onnx")
.input("tests/reduce_prod/reduce_prod.onnx")
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
.input("tests/relu/relu.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/resize/resize.onnx")
.input("tests/resize/resize_with_sizes.onnx")
.input("tests/resize/resize_1d_linear_scale.onnx")
.input("tests/resize/resize_1d_nearest_scale.onnx")
.input("tests/resize/resize_2d_bicubic_scale.onnx")
.input("tests/resize/resize_2d_bilinear_scale.onnx")
.input("tests/resize/resize_2d_nearest_scale.onnx")
.input("tests/shape/shape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/sign/sign.onnx")
.input("tests/sin/sin.onnx")
.input("tests/slice/slice.onnx")
.input("tests/softmax/softmax.onnx")
.input("tests/sqrt/sqrt.onnx")
.input("tests/sub/sub_int.onnx")
.input("tests/squeeze/squeeze_multiple.onnx")
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/squeeze/squeeze_opset16.onnx")
.input("tests/sub/sub.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
.input("tests/conv_transpose3d/conv_transpose3d.onnx")
.input("tests/pow/pow.onnx")
.input("tests/pow/pow_int.onnx")
.input("tests/slice/slice.onnx")
.input("tests/sub/sub_int.onnx")
.input("tests/sum/sum.onnx")
.input("tests/sum/sum_int.onnx")
.input("tests/tanh/tanh.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
.input("tests/mask_where/mask_where.onnx")
.input("tests/squeeze/squeeze_opset16.onnx")
.input("tests/squeeze/squeeze_opset13.onnx")
.input("tests/squeeze/squeeze_multiple.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/constant_of_shape/constant_of_shape.onnx")
.input("tests/constant_of_shape/constant_of_shape_full_like.onnx")
.input("tests/range/range.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.out_dir("model/")
.run_from_script();

View File

@ -15,19 +15,23 @@ macro_rules! include_models {
// ATTENTION: Modify this macro to include all models in the `model` directory.
include_models!(
add_int,
add,
add_int,
argmax,
avg_pool2d,
avg_pool1d,
avg_pool2d,
batch_norm,
cast,
clip_opset16,
clip_opset7,
concat,
constant_of_shape,
constant_of_shape_full_like,
conv1d,
conv2d,
conv3d,
conv_transpose2d,
conv_transpose3d,
cos,
div,
dropout_opset16,
@ -41,37 +45,46 @@ include_models!(
gather_elements,
gelu,
global_avr_pool,
greater,
greater_or_equal,
layer_norm,
leaky_relu,
less,
less_or_equal,
linear,
log_softmax,
log,
log_softmax,
mask_where,
matmul,
min,
max,
maxpool1d,
maxpool2d,
min,
mul,
neg,
not,
pad,
greater,
greater_or_equal,
less,
less_or_equal,
pow,
pow_int,
prelu,
random_normal,
random_uniform,
range,
recip,
reduce_max,
reduce_min,
reduce_mean,
reduce_min,
reduce_prod,
reduce_sum_opset13,
reduce_sum_opset11,
reduce_sum_opset13,
relu,
reshape,
resize,
resize_with_sizes,
resize_1d_linear_scale,
resize_1d_nearest_scale,
resize_2d_bicubic_scale,
resize_2d_bilinear_scale,
resize_2d_nearest_scale,
shape,
sigmoid,
sign,
@ -79,26 +92,18 @@ include_models!(
slice,
softmax,
sqrt,
sub_int,
squeeze_multiple,
squeeze_opset13,
squeeze_opset16,
sub,
sub_int,
sum,
sum_int,
tanh,
transpose,
conv_transpose2d,
conv_transpose3d,
pow,
pow_int,
unsqueeze,
unsqueeze_opset16,
unsqueeze_opset11,
squeeze_opset16,
squeeze_opset13,
squeeze_multiple,
random_uniform,
random_normal,
constant_of_shape,
constant_of_shape_full_like
unsqueeze_opset16
);
#[cfg(test)]
@ -865,10 +870,10 @@ mod tests {
}
#[test]
fn resize() {
fn resize_with_sizes() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: resize::Model<Backend> = resize::Model::new(&device);
let model: resize_with_sizes::Model<Backend> = resize_with_sizes::Model::new(&device);
// Run the model
let input = Tensor::<Backend, 4>::from_floats(
@ -880,14 +885,153 @@ mod tests {
]]],
&device,
);
let size = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 3], &device);
let output = model.forward(input, size);
// The sizes are [1, 1, 2, 3]
let output = model.forward(input);
let expected = TensorData::from([[[[0.0f32, 1.5, 3.0], [12.0, 13.5, 15.0]]]]);
output.to_data().assert_eq(&expected, true);
}
/// Linear 1D resize driven by a scale factor; only the sum of the output is compared
/// against a reference value.
///
/// Currently ignored, see the issue linked in the `ignore` attribute.
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
fn resize_with_scales_1d_linear() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: resize_1d_linear_scale::Model<Backend> =
resize_1d_linear_scale::Model::new(&device);
// Run the model
let input = Tensor::<Backend, 3>::from_floats(
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
&device,
);
// The scales are 1.5
let output = model.forward(input);
let output_sum = output.sum().into_scalar();
// NOTE(review): this is the same value the nearest-neighbour 1D test expects for the same
// input; confirm it was regenerated for linear mode before re-enabling this test.
let expected_sum = -4.568_224; // from pytorch
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
}
/// Bilinear 2D resize driven by scale factors; only the sum of the output is compared
/// against a reference value.
#[test]
fn resize_with_scales_2d_bilinear() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: resize_2d_bilinear_scale::Model<Backend> =
resize_2d_bilinear_scale::Model::new(&device);
// Run the model on a single 6x6 image ([N, C, H, W] = [1, 1, 6, 6])
let input = Tensor::<Backend, 4>::from_floats(
[[[
[-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920],
[-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081],
[0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959],
[0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412],
[-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022],
[-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048],
]]],
&device,
);
// The scales are 1.5, 1.5
let output = model.forward(input);
let output_sum = output.sum().into_scalar();
let expected_sum = -3.401_126_6; // from pytorch
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
}
/// Nearest-neighbour 2D resize driven by scale factors; checks the output shape and the
/// sum of the output against a reference value.
#[test]
fn resize_with_scales_2d_nearest() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: resize_2d_nearest_scale::Model<Backend> =
resize_2d_nearest_scale::Model::<Backend>::new(&device);
// Run the model on a single 6x6 image ([N, C, H, W] = [1, 1, 6, 6])
let input = Tensor::<Backend, 4>::from_floats(
[[[
[-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920],
[-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081],
[0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959],
[0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412],
[-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022],
[-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048],
]]],
&device,
);
// The scales are 1.5, 1.5
let output = model.forward(input);
// 6 * 1.5 = 9 along both spatial axes
assert_eq!(output.dims(), [1, 1, 9, 9]);
let output_sum = output.sum().into_scalar();
let expected_sum = -0.812_227_7; // from pytorch
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
}
/// Nearest-neighbour 1D resize driven by a scale factor; checks the output shape and the
/// sum of the output against a reference value.
#[test]
fn resize_with_scales_1d_nearest() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: resize_1d_nearest_scale::Model<Backend> =
resize_1d_nearest_scale::Model::<Backend>::new(&device);
// Run the model
let input = Tensor::<Backend, 3>::from_floats(
[[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
&device,
);
// The scale is 1.5 (a 1D input has a single length axis to resize)
let output = model.forward(input);
// 6 * 1.5 = 9 along the length axis
assert_eq!(output.dims(), [1, 1, 9]);
let output_sum = output.sum().into_scalar();
let expected_sum = -4.568_224; // from pytorch
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
}
/// Bicubic 2D resize driven by scale factors; checks the output shape and the sum of the
/// output against a reference value.
#[test]
fn resize_with_scales_2d_bicubic() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: resize_2d_bicubic_scale::Model<Backend> =
resize_2d_bicubic_scale::Model::<Backend>::new(&device);
// Run the model on a single 6x6 image ([N, C, H, W] = [1, 1, 6, 6])
let input = Tensor::<Backend, 4>::from_floats(
[[[
[-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920],
[-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081],
[0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959],
[0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412],
[-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022],
[-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048],
]]],
&device,
);
// The scales are 1.5, 1.5
let output = model.forward(input);
// 6 * 1.5 = 9 along both spatial axes
assert_eq!(output.dims(), [1, 1, 9, 9]);
let output_sum = output.sum().into_scalar();
let expected_sum = -3.515_921; // from pytorch
// Note the looser epsilon (1.0e-3) compared with the other resize tests
assert!(expected_sum.approx_eq(output_sum, (1.0e-3, 2)));
}
#[test]
fn shape() {
let device = Default::default();

View File

@ -0,0 +1,71 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torch.onnx
class InterpolateModel(nn.Module):
    """Module wrapper around ``nn.functional.interpolate``.

    The constructor arguments are stored on the instance and passed unchanged to
    ``nn.functional.interpolate`` on every forward call.
    """

    def __init__(self, scale_factor=None, size=None, mode='nearest', align_corners=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` using the stored interpolation settings."""
        return nn.functional.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
def export_interpolate_onnx(filename, batch_size=1, channels=1, height=6, width=6,
                            scale_factor=None, size=None, mode='nearest', dim=2, align_corners=None):
    """Export an ``InterpolateModel`` to ONNX and print a summary of a sample run.

    A seeded random input is created (``[batch, channels, width]`` for ``dim=1``,
    ``[batch, channels, height, width]`` for ``dim=2``), the model is exported to
    ``filename`` with opset 17 and a dynamic batch axis, and the input/output shapes,
    sums and values are printed so they can be copied into tests.

    Raises:
        ValueError: if ``dim`` is neither 1 nor 2.
    """
    model = InterpolateModel(scale_factor, size, mode, align_corners)
    model.eval()

    # Add seed for reproducibility
    torch.manual_seed(0)

    # Pick the shape of the dummy input from the requested dimensionality
    if dim == 1:
        input_shape = (batch_size, channels, width)
    elif dim == 2:
        input_shape = (batch_size, channels, height, width)
    else:
        raise ValueError("Unsupported dimension. Use 1 for temporal or 2 for spatial.")
    dummy_input = torch.randn(*input_shape)

    # Export the model
    torch.onnx.export(
        model,
        dummy_input,
        filename,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'},
        },
        opset_version=17,
    )

    # Run the model once so reference values can be printed
    output = model(dummy_input)

    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")

    # Print sum data
    print(f"Input sum: {dummy_input.sum()}")
    print(f"Output sum: {output.sum()}")

    print(f"Input: {dummy_input}")
    print(f"Output: {output}")

    print(f"Model exported to {filename}")
    print()
# Usage examples: each call below exports one ONNX file, all with a scale factor of 1.5.
if __name__ == "__main__":
# 1D (temporal) examples
export_interpolate_onnx("resize_1d_nearest_scale.onnx", scale_factor=1.5, mode='nearest', dim=1)
export_interpolate_onnx("resize_1d_linear_scale.onnx", scale_factor=1.5, mode='linear', dim=1, align_corners=True)
# Cubic interpolation is not supported for 1D tensors
# export_interpolate_onnx("resize_1d_cubic_scale.onnx", scale_factor=1.5, mode='cubic', dim=1)
# 2D (spatial) examples; the bilinear and bicubic exports use align_corners=True
export_interpolate_onnx("resize_2d_nearest_scale.onnx", scale_factor=1.5, mode='nearest', dim=2)
export_interpolate_onnx("resize_2d_bilinear_scale.onnx", scale_factor=1.5, mode='bilinear', dim=2, align_corners=True)
export_interpolate_onnx("resize_2d_bicubic_scale.onnx", scale_factor=1.5, mode='bicubic', dim=2, align_corners=True)

View File

@ -4,10 +4,19 @@
import onnx
from onnx import helper, TensorProto
import numpy as np
def main() -> None:
input_tensor = helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [1, 1, 4, 4])
sizes_tensor = helper.make_tensor_value_info("sizes", TensorProto.INT64, [4])
# Create sizes as a constant tensor
sizes = np.array([1, 1, 2, 3], dtype=np.int64)
sizes_tensor = helper.make_tensor(
name="sizes",
data_type=TensorProto.INT64,
dims=sizes.shape,
vals=sizes.flatten().tolist(),
)
resize_node = helper.make_node(
"Resize",
@ -20,15 +29,16 @@ def main() -> None:
graph_def = helper.make_graph(
nodes=[resize_node],
name="ResizeGraph",
inputs=[input_tensor, sizes_tensor],
inputs=[input_tensor],
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1, 2, 2])
],
initializer=[sizes_tensor],
)
model_def = helper.make_model(graph_def, producer_name="resize")
onnx.save(model_def, "resize.onnx")
onnx.save(model_def, "resize_with_sizes.onnx")
if __name__ == "__main__":

View File

@ -64,6 +64,13 @@ impl ToTokens for f64 {
}
}
/// Prettier output for `f32`
impl ToTokens for f32 {
/// Produces the token stream for this `f32` value.
///
/// All of the work is delegated to `convert_primitive`, which is defined
/// outside this excerpt; the exact formatting it applies is not visible here.
fn to_tokens(&self) -> TokenStream {
convert_primitive(self)
}
}
/// Padding configuration
impl ToTokens for PaddingConfig1d {
fn to_tokens(&self) -> TokenStream {

View File

@ -1,29 +1,17 @@
use super::{Node, NodeCodegen};
use crate::burn::{OtherType, Scope, TensorType, Type};
use burn::module::Module;
use crate::burn::{OtherType, Scope, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
#[derive(Module, Debug, Clone)]
pub enum ResizeMode {
Nearest,
Linear,
Cubic,
}
#[derive(new, Module, Debug, Clone)]
pub struct ResizeOptions {
pub mode: ResizeMode,
}
#[derive(Debug, Clone)]
pub struct ResizeNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub output_size: TensorType,
pub config: ResizeOptions,
mode: String,
scales: Vec<f32>,
sizes: Vec<usize>,
}
impl ResizeNode {
@ -31,20 +19,29 @@ impl ResizeNode {
name: S,
input: TensorType,
output: TensorType,
output_size: TensorType,
config: ResizeOptions,
mode: String,
scales: Vec<f32>,
sizes: Vec<usize>,
) -> Self {
let ty = if input.dim == 3 {
quote! {
Interpolate1d
}
} else if input.dim == 4 {
quote! {
Interpolate2d
}
} else {
panic!("Unsupported input dimension for resize node");
};
Self {
field: OtherType::new(
name,
quote! {
burn::module::Ignored<InterpolateOptions>
},
),
field: OtherType::new(name, ty),
input,
output,
output_size,
config,
mode,
scales,
sizes,
}
}
}
@ -55,10 +52,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ResizeNode {
}
fn input_types(&self) -> Vec<Type> {
vec![
Type::Tensor(self.input.clone()),
Type::Tensor(self.output_size.clone()),
]
vec![Type::Tensor(self.input.clone())]
}
fn field_type(&self) -> Option<Type> {
@ -68,59 +62,96 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ResizeNode {
fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let mode = match self.config.mode {
ResizeMode::Linear => quote! { InterpolateMode::Bilinear },
ResizeMode::Nearest => quote! { InterpolateMode::Nearest },
ResizeMode::Cubic => quote! { InterpolateMode::Bicubic },
let mode = match self.mode.as_str() {
"nearest" => quote! { InterpolateMode::Nearest },
"linear" => quote! { InterpolateMode::Linear },
"cubic" => quote! { InterpolateMode::Cubic },
_ => panic!("Unsupported mode for resize node"),
};
let tokens = quote! {
let #name = InterpolateOptions {
mode: #mode,
let tokens = if self.input.dim == 3 {
let size = if let Some(size) = self.sizes.first() {
let size = size.to_tokens();
quote! { Some(#size) }
} else {
quote! { None }
};
let #name = burn::module::Ignored(#name);
let scale_factor = if let Some(scale) = self.scales.first() {
let scale = scale.to_tokens();
quote! { Some(#scale) }
} else {
quote! { None }
};
quote! {
let #name = Interpolate1dConfig::new()
.with_output_size(#size)
.with_scale_factor(#scale_factor)
.with_mode(#mode)
.init();
}
} else if self.input.dim == 4 {
let size = if self.sizes.len() == 2 {
let h = self.sizes[0].to_tokens();
let w = self.sizes[1].to_tokens();
quote! { Some([#h, #w]) }
} else {
quote! { None }
};
let scale_factor = if self.scales.len() == 2 {
let h = self.scales[0].to_tokens();
let w = self.scales[1].to_tokens();
quote! { Some([#h, #w]) }
} else {
quote! { None }
};
quote! {
let #name = Interpolate2dConfig::new()
.with_output_size(#size)
.with_scale_factor(#scale_factor)
.with_mode(#mode)
.init();
}
} else {
panic!("Unsupported input dimension for resize node");
};
Some(tokens)
}
fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::nn::interpolate::InterpolateMode");
if self.input.dim == 3 {
imports.register("burn::nn::interpolate::Interpolate1dConfig");
imports.register("burn::nn::interpolate::Interpolate1d");
} else if self.input.dim == 4 {
imports.register("burn::nn::interpolate::Interpolate2dConfig");
imports.register("burn::nn::interpolate::Interpolate2d");
} else {
panic!("Unsupported input dimension for resize node");
}
}
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output_size = scope.tensor_use_owned(&self.output_size, node_position);
let output = &self.output.name;
let field = &self.field.name;
quote! {
let output_size_data = #output_size.to_data();
let mut output_size = [0usize; 2];
for (i, &x) in output_size_data.as_slice::<B::IntElem>().unwrap().iter().rev().take(2).rev().enumerate() {
output_size[i] = x.elem::<i64>() as usize;
}
let #output = interpolate(
#input,
output_size,
self.#field.0.clone(),
);
let #output = self.#field.forward(#input);
}
}
fn into_node(self) -> Node<PS> {
Node::Resize(self)
}
fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::ElementConversion");
imports.register("burn::tensor::module::interpolate");
imports.register("burn::tensor::ops::InterpolateMode");
imports.register("burn::tensor::ops::InterpolateOptions");
}
}
#[cfg(test)]
@ -135,47 +166,42 @@ mod tests {
};
#[test]
fn test_codegen_nodes() {
fn test_codegen_nodes_2d() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(ResizeNode::new(
"resize",
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
TensorType::new_int("output_size", 1),
ResizeOptions::new(ResizeMode::Linear),
"nearest".to_string(),
vec![0.5, 0.5],
vec![],
));
graph.register_input_output(
vec!["tensor1".to_string(), "output_size".to_string()],
vec!["tensor2".to_string()],
);
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
let expected = quote! {
use burn::tensor::module::interpolate;
use burn::tensor::ops::InterpolateMode;
use burn::tensor::ops::InterpolateOptions;
use burn::tensor::ElementConversion;
use burn::tensor::Int;
use burn::nn::interpolate::Interpolate2d;
use burn::nn::interpolate::Interpolate2dConfig;
use burn::nn::interpolate::InterpolateMode;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
resize: burn::module::Ignored<InterpolateOptions>,
resize: Interpolate2d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let resize = InterpolateOptions {
mode: InterpolateMode::Bilinear,
};
let resize = burn::module::Ignored(resize);
let resize = Interpolate2dConfig::new()
.with_output_size(None)
.with_scale_factor(Some([0.5, 0.5]))
.with_mode(InterpolateMode::Nearest)
.init();
Self {
resize,
phantom: core::marker::PhantomData,
@ -183,20 +209,62 @@ mod tests {
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
output_size: Tensor<B, 1, Int>
) -> Tensor<B, 4> {
let output_size_data = output_size.to_data();
let mut output_size = [0usize; 2];
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
let tensor2 = self.resize.forward(tensor1);
tensor2
}
}
};
for (i, &x) in output_size_data.as_slice::<B::IntElem>().unwrap().iter().rev().take(2).rev().enumerate() {
output_size[i] = x.elem::<i64>() as usize;
assert_tokens(graph.codegen(), expected);
}
#[test]
fn test_codegen_nodes_1d() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(ResizeNode::new(
"resize",
TensorType::new_float("tensor1", 3),
TensorType::new_float("tensor2", 3),
"cubic".to_string(),
vec![],
vec![20],
));
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
let expected = quote! {
use burn::nn::interpolate::Interpolate1d;
use burn::nn::interpolate::Interpolate1dConfig;
use burn::nn::interpolate::InterpolateMode;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
resize: Interpolate1d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let resize = Interpolate1dConfig::new()
.with_output_size(Some(20))
.with_scale_factor(None)
.with_mode(InterpolateMode::Cubic)
.init();
Self {
resize,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
let tensor2 = interpolate(tensor1, output_size, self.resize.0.clone());
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, tensor1: Tensor<B, 3>) -> Tensor<B, 3> {
let tensor2 = self.resize.forward(tensor1);
tensor2
}
}

View File

@ -7,7 +7,7 @@ use burn::nn::{
PaddingConfig2d, PaddingConfig3d,
};
use crate::burn::node::{pad::PadConfig, resize::ResizeMode};
use crate::burn::node::pad::PadConfig;
use onnx_ir::ir::{ArgType, AttributeValue, Data, Node};
/// Create a Conv1dConfig from the attributes of the node
@ -976,26 +976,132 @@ pub fn reshape_config(node: &Node) -> Vec<i64> {
}
}
pub fn resize_config(node: &Node) -> ResizeMode {
pub fn resize_config(node: &Node) -> (String, Vec<f32>, Vec<usize>) {
let mut mode: String = "".to_string();
let mut scales: Vec<f32>;
let mut sizes: Vec<usize>;
let input = if let ArgType::Tensor(tensor) = &node
.inputs
.first()
.expect("Resize: Input tensor must be present")
.ty
{
tensor
} else {
panic!("Resize: input must be a tensor")
};
// Note: we are ignoring some attributes because results are approximately the same
// and we are not supporting all the attributes of the Resize operator.
// However, some attributes are important to be checked and we are checking
// against the default values of the attributes.
// TODO revisit this when we have more Resize operators in the model
for (key, value) in node.attrs.iter() {
match key.as_str() {
"coordinate_transformation_mode" => {}
"cubic_coeff_a" => {}
"mode" => mode = value.clone().into_string(),
"nearest_mode" => {}
"antialias" => assert_eq!(
value.clone().into_i32(),
0,
"Resize: antialias other than 0 is not supported"
),
"axes" => panic!("Resize: custom axes attribute is not supported"),
"coordinate_transformation_mode" => {
log::warn!("Resize: coordinate_transformation_mode is ignored")
}
"cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"),
"exclude_outside" => assert_eq!(
value.clone().into_i32(),
0,
"Resize: exclude_outside other than 0 is not supported"
),
"extrapolation_value" => assert_eq!(
value.clone().into_f32(),
0.0,
"Resize: extrapolation_value other than 0.0 is not supported"
),
"keep_aspect_ratio_policy" => {
assert_eq!(
value.clone().into_string().to_lowercase(),
"stretch",
"Resize: keep_aspect_ratio_policy other than 'stretch' is not supported"
)
}
"mode" => mode = value.clone().into_string().to_lowercase(),
"nearest_mode" => log::warn!("Resize: nearest_mode is ignored"),
_ => {}
}
}
let mode = match mode.as_str() {
"nearest" => ResizeMode::Nearest,
"linear" => ResizeMode::Linear,
"cubic" => ResizeMode::Cubic,
_ => panic!("Resize: invalid mode string, must be 'nearest', 'linear', or 'cubic'"),
};
let roi: Vec<f32> = node
.inputs
.get(1)
.map(|input| {
if let Some(data) = &input.value {
data.clone().into_f32s()
} else {
vec![]
}
})
.unwrap_or_default();
mode
scales = node
.inputs
.get(2)
.map(|input| {
if let Some(data) = &input.value {
data.clone().into_f32s()
} else {
vec![]
}
})
.unwrap_or_default();
sizes = node
.inputs
.get(3)
.map(|input| {
if let Some(data) = &input.value {
data.clone()
.into_i64s()
.iter()
.map(|&x| x as usize)
.collect()
} else {
vec![]
}
})
.unwrap_or_default();
if mode.is_empty() {
panic!("Resize: mode attribute is required")
}
if !roi.is_empty() {
panic!("Resize: roi input is not supported")
}
if scales.is_empty() && sizes.is_empty() {
panic!("Resize: either scales or sizes input is required")
}
if !scales.is_empty() {
assert!(scales.len() == input.dim);
// ignore the first two items from scales
// because they are the batch and channel dimensions
scales = scales.iter().skip(2).cloned().collect();
}
if !sizes.is_empty() {
assert!(sizes.len() == input.dim);
// ignore the first two items from sizes
// because they are the batch and channel dimensions
sizes = sizes.iter().skip(2).cloned().collect();
}
(mode, scales, sizes)
}
//Note this function should only execute if the second input is a constant

View File

@ -46,7 +46,7 @@ use crate::{
random_uniform::RandomUniformNode,
range::RangeNode,
reshape::ReshapeNode,
resize::{ResizeNode, ResizeOptions},
resize::ResizeNode,
slice::SliceNode,
squeeze::SqueezeNode,
sum::SumNode,
@ -646,13 +646,12 @@ impl ParsedOnnxGraph {
let name = &node.name;
let input = TensorType::from(&node.inputs[0]);
let output_size = TensorType::from(&node.inputs[3]);
let output = TensorType::from(node.outputs.first().unwrap());
let mode = resize_config(&node);
let (mode, scales, sizes) = resize_config(&node);
ResizeNode::new(name, input, output, output_size, ResizeOptions { mode })
ResizeNode::new(name, input, output, mode, scales, sizes)
}
fn min_conversion(node: Node) -> BinaryNode {

View File

@ -128,7 +128,7 @@ pub struct UnfoldOptions {
}
/// Algorithm used for upsampling.
#[derive(new, Debug, Clone)]
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
/// Nearest-neighbor interpolation.
/// <https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation>

View File

@ -85,6 +85,47 @@ mod tests {
]]]));
}
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
fn test_1d_bicubic() {
    // Ignored for now; see the issue linked in the `#[ignore]` attribute.
    let device = Default::default();

    // One row of six samples. `unsqueeze_dim(2)` adds an extra axis so the
    // rank-3 tensor can be passed to `interpolate`, whose result is rank 4.
    let signal = TestTensor::<3>::from_floats(
        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
        &device,
    )
    .unsqueeze_dim(2);

    // Resize to an output size of [1, 9] using bicubic interpolation.
    let options = InterpolateOptions::new(InterpolateMode::Bicubic);
    let resized = interpolate(signal, [1, 9], options);
    assert_eq!(resized.dims(), [1, 1, 1, 9]);

    // Every produced value must be a real number (no NaN anywhere).
    let snapshot = resized.clone().to_data();
    let nan_free = snapshot
        .as_slice::<f32>()
        .unwrap()
        .iter()
        .all(|&sample| !sample.is_nan());
    assert!(nan_free, "interpolate output contains NaN");

    // Compare against the reference values with the precision argument 3.
    let expected = TestTensor::<4>::from([[[[
        1.541, 0.5747652, -1.010614, -2.197787, -0.8269969, 0.59609234, -0.5803058, -1.3792794,
        -1.3986,
    ]]]]);
    expected
        .to_data()
        .assert_approx_eq(&resized.into_data(), 3);
}
struct InterpolateTestCase {
batch_size: usize,
channels: usize,

View File

@ -85,6 +85,55 @@ mod tests {
]]]));
}
#[test]
#[ignore = "https://github.com/tracel-ai/burn/issues/2080"]
fn test_1d_bilinear() {
    // Ignored for now; see the issue linked in the `#[ignore]` attribute.
    let device = Default::default();

    // One row of six samples. `unsqueeze_dim(2)` adds an extra axis so the
    // rank-3 tensor can be passed to `interpolate`, whose result is rank 4.
    let signal = TestTensor::<3>::from_floats(
        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
        &device,
    )
    .unsqueeze_dim(2);

    // Resize to an output size of [1, 9] using bilinear interpolation.
    let options = InterpolateOptions::new(InterpolateMode::Bilinear);
    let resized = interpolate(signal, [1, 9], options);
    assert_eq!(resized.dims(), [1, 1, 1, 9]);

    // Every produced value must be a real number (no NaN anywhere).
    let snapshot = resized.clone().to_data();
    let nan_free = snapshot
        .as_slice::<f32>()
        .unwrap()
        .iter()
        .all(|&sample| !sample.is_nan());
    assert!(nan_free, "interpolate output contains NaN");

    // Compare against the reference values with the precision argument 3.
    let expected = TestTensor::<4>::from([[[[
        1.541f32,
        0.39450002,
        -0.76475,
        -1.943125,
        -0.80520004,
        0.36178753,
        -0.671275,
        -1.2022874,
        -1.3986,
    ]]]]);
    expected
        .to_data()
        .assert_approx_eq(&resized.into_data(), 3);
}
struct InterpolateTestCase {
batch_size: usize,
channels: usize,

View File

@ -59,6 +59,45 @@ mod tests {
]]]));
}
#[test]
fn test_1d_nearest() {
    let device = Default::default();

    // One row of six samples. `unsqueeze_dim(2)` adds an extra axis so the
    // rank-3 tensor can be passed to `interpolate`, whose result is rank 4.
    let signal = TestTensor::<3>::from_floats(
        [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]],
        &device,
    )
    .unsqueeze_dim(2);

    // Resize to an output size of [1, 9] using nearest-neighbor interpolation.
    let options = InterpolateOptions::new(InterpolateMode::Nearest);
    let resized = interpolate(signal, [1, 9], options);
    assert_eq!(resized.dims(), [1, 1, 1, 9]);

    // Every produced value must be a real number (no NaN anywhere).
    let snapshot = resized.clone().to_data();
    let nan_free = snapshot
        .as_slice::<f32>()
        .unwrap()
        .iter()
        .all(|&sample| !sample.is_nan());
    assert!(nan_free, "interpolate output contains NaN");

    // Compare against the reference values with the precision argument 3.
    let expected = TestTensor::<4>::from([[[[
        1.541, 1.541, -0.2934, -2.1788, -2.1788, 0.5684, -1.0845, -1.0845, -1.3986,
    ]]]]);
    expected
        .to_data()
        .assert_approx_eq(&resized.into_data(), 3);
}
struct InterpolateTestCase {
batch_size: usize,
channels: usize,

View File

@ -62,7 +62,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::ReduceSum => reduce_sum_update_outputs(node),
NodeType::Relu => same_as_input(node),
NodeType::Reshape => reshape_update_outputs(node),
NodeType::Resize => resize_update_outputs(node),
NodeType::Resize => same_as_input(node),
NodeType::Shape => shape_update_outputs(node),
NodeType::Sigmoid => same_as_input(node),
NodeType::Sign => same_as_input(node),
@ -318,33 +318,6 @@ fn reshape_update_outputs(node: &mut Node) {
}
}
/// Updates the output type of a `Resize` node.
///
/// The first output is rewritten as a tensor that keeps every field of the
/// existing output type except `dim`, which is copied from the first input,
/// and `shape`, which is cleared.
///
/// # Panics
///
/// Panics if the first input, the first output, or the input at index 3 is
/// not a tensor, or if the input at index 3 does not have `dim == 1`.
fn resize_update_outputs(node: &mut Node) {
    // Rank of the tensor being resized.
    let rank = if let ArgType::Tensor(data) = &node.inputs[0].ty {
        data.dim
    } else {
        panic!("Resize: invalid input type")
    };

    // Current output type; used as the base for the updated one.
    let template = if let ArgType::Tensor(out) = &node.outputs[0].ty {
        out.clone()
    } else {
        panic!("Resize: invalid output type")
    };

    // The input at index 3 carries the requested output size.
    let sizes_rank = if let ArgType::Tensor(sizes) = &node.inputs[3].ty {
        sizes.dim
    } else {
        panic!("Resize: invalid output_size type")
    };

    if sizes_rank != 1 {
        panic!("Resize: output_size must be 1D");
    }

    node.outputs[0].ty = ArgType::Tensor(TensorType {
        dim: rank,
        shape: None, // shape is calculated at runtime
        ..template
    });
}
fn greater_update_outputs(node: &mut Node) {
match &node.inputs[0].ty {
ArgType::Tensor(tensor) => {
@ -838,7 +811,7 @@ fn gather_update_outputs(node: &mut Node) {
let input_tensor = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor,
_ => panic!("Only tensor input is valid"),
ty => panic!("Only tensor input is valid but received: {:?}", ty),
};
let indices_tensor = match &node.inputs[1].ty {

View File

@ -523,44 +523,54 @@ impl Data {
_ => self,
}
}
pub fn into_f16(self) -> f16 {
if let Data::Float16(elem) = self {
elem
} else {
panic!("Expected Float16, got {:?}", self);
match self {
Data::Float16(elem) => elem,
Data::Float32(elem) => f16::from_f32(elem),
Data::Float64(elem) => f16::from_f64(elem),
_ => panic!("Cannot convert {:?} to f16", self),
}
}
pub fn into_f32(self) -> f32 {
if let Data::Float32(elem) = self {
elem
} else {
panic!("Expected Float32, got {:?}", self);
match self {
Data::Float16(elem) => elem.to_f32(),
Data::Float32(elem) => elem,
Data::Float64(elem) => elem as f32,
Data::Int32(elem) => elem as f32,
Data::Int64(elem) => elem as f32,
_ => panic!("Cannot convert {:?} to f32", self),
}
}
pub fn into_f64(self) -> f64 {
if let Data::Float64(elem) = self {
elem
} else {
panic!("Expected Float64, got {:?}", self);
match self {
Data::Float16(elem) => elem.to_f64(),
Data::Float32(elem) => elem as f64,
Data::Float64(elem) => elem,
Data::Int32(elem) => elem as f64,
Data::Int64(elem) => elem as f64,
_ => panic!("Cannot convert {:?} to f64", self),
}
}
pub fn into_i32(self) -> i32 {
if let Data::Int32(elem) = self {
elem
} else {
panic!("Expected Int32, got {:?}", self);
match self {
Data::Int32(elem) => elem,
Data::Int64(elem) => elem as i32,
Data::Float32(elem) => elem as i32,
Data::Float64(elem) => elem as i32,
_ => panic!("Cannot convert {:?} to i32", self),
}
}
pub fn into_i64(self) -> i64 {
if let Data::Int64(elem) = self {
elem
} else {
panic!("Expected Int64, got {:?}", self);
match self {
Data::Int32(elem) => elem as i64,
Data::Int64(elem) => elem,
Data::Float32(elem) => elem as i64,
Data::Float64(elem) => elem as i64,
_ => panic!("Cannot convert {:?} to i64", self),
}
}
@ -581,42 +591,53 @@ impl Data {
}
pub fn into_f16s(self) -> Vec<f16> {
if let Data::Float16s(elem) = self {
elem
} else {
panic!("Expected Float16s, got {:?}", self);
match self {
Data::Float16s(elem) => elem,
Data::Float32s(elem) => elem.into_iter().map(f16::from_f32).collect(),
Data::Float64s(elem) => elem.into_iter().map(f16::from_f64).collect(),
_ => panic!("Cannot convert {:?} to Vec<f16>", self),
}
}
pub fn into_f32s(self) -> Vec<f32> {
if let Data::Float32s(elem) = self {
elem
} else {
panic!("Expected Float32s, got {:?}", self);
match self {
Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f32()).collect(),
Data::Float32s(elem) => elem,
Data::Float64s(elem) => elem.into_iter().map(|x| x as f32).collect(),
Data::Int32s(elem) => elem.into_iter().map(|x| x as f32).collect(),
Data::Int64s(elem) => elem.into_iter().map(|x| x as f32).collect(),
_ => panic!("Cannot convert {:?} to Vec<f32>", self),
}
}
pub fn into_f64s(self) -> Vec<f64> {
if let Data::Float64s(elem) = self {
elem
} else {
panic!("Expected Float64s, got {:?}", self);
match self {
Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f64()).collect(),
Data::Float32s(elem) => elem.into_iter().map(|x| x as f64).collect(),
Data::Float64s(elem) => elem,
Data::Int32s(elem) => elem.into_iter().map(|x| x as f64).collect(),
Data::Int64s(elem) => elem.into_iter().map(|x| x as f64).collect(),
_ => panic!("Cannot convert {:?} to Vec<f64>", self),
}
}
pub fn into_i32s(self) -> Vec<i32> {
if let Data::Int32s(elem) = self {
elem
} else {
panic!("Expected Int32s, got {:?}", self);
match self {
Data::Int32s(elem) => elem,
Data::Int64s(elem) => elem.into_iter().map(|x| x as i32).collect(),
Data::Float32s(elem) => elem.into_iter().map(|x| x as i32).collect(),
Data::Float64s(elem) => elem.into_iter().map(|x| x as i32).collect(),
_ => panic!("Cannot convert {:?} to Vec<i32>", self),
}
}
pub fn into_i64s(self) -> Vec<i64> {
if let Data::Int64s(elem) = self {
elem
} else {
panic!("Expected Int64s, got {:?}", self);
match self {
Data::Int32s(elem) => elem.into_iter().map(|x| x as i64).collect(),
Data::Int64s(elem) => elem,
Data::Float32s(elem) => elem.into_iter().map(|x| x as i64).collect(),
Data::Float64s(elem) => elem.into_iter().map(|x| x as i64).collect(),
_ => panic!("Cannot convert {:?} to Vec<i64>", self),
}
}