ONNX support for scalar unsqueeze (#1690)

* Revert 1c639c8393

* Refactor by @laggui

* Refactor unsqueeze

* Add support for scalar unsqueeze

* Removed dead comment
Author: Dilshod Tadjibaev
Date: 2024-04-25 16:05:28 -05:00 (committed by GitHub)
Parent: 599a20d586
Commit: 67ec06d5d8
27 changed files with 176 additions and 31 deletions

View File

@@ -208,6 +208,67 @@ impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
}
}
/// Container to satisfy the Module trait for types that are not modules.
#[derive(Clone, Debug)]
pub struct Ignored<T>(pub T);
impl<B, T> Module<B> for Ignored<T>
where
B: Backend,
T: Sync + Send + core::fmt::Debug + Clone,
{
type Record = ConstantRecord;
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
fn load_record(self, _record: Self::Record) -> Self {
self
}
fn into_record(self) -> Self::Record {
ConstantRecord::new()
}
fn to_device(self, _: &<B as Backend>::Device) -> Self {
self
}
fn fork(self, _: &<B as Backend>::Device) -> Self {
self
}
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
devices
}
}
impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
where
B: AutodiffBackend,
T: Sync + Send + core::fmt::Debug + Clone,
{
type InnerModule = Ignored<T>;
fn valid(&self) -> Self::InnerModule {
self.clone()
}
}
// Implement deref for Ignored
impl<T> core::ops::Deref for Ignored<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(all(test, feature = "std"))]
mod tests {
use core::marker::PhantomData;
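The `Ignored<T>` wrapper added above is what lets generated models carry non-module state, most importantly a device handle, through the `Module` machinery. A minimal sketch of the pattern (the `DeviceAware` name is illustrative, not from this commit):

use burn::module::{Ignored, Module};
use burn::tensor::{backend::Backend, Tensor};

// Sketch: store the construction device without requiring
// `B::Device` to implement `Module` itself.
#[derive(Module, Debug)]
pub struct DeviceAware<B: Backend> {
    device: Ignored<B::Device>,
}

impl<B: Backend> DeviceAware<B> {
    pub fn new(device: &B::Device) -> Self {
        Self { device: Ignored(device.clone()) }
    }

    pub fn forward(&self) -> Tensor<B, 1> {
        // The `Deref` impl lets `&self.device` coerce to `&B::Device`.
        Tensor::ones([1], &self.device)
    }
}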

View File

@@ -1068,8 +1068,9 @@ mod tests {
let input_shape = Shape::from([3, 4, 5]);
let expected_shape = Shape::from([3, 4, 5, 1]);
let input = Tensor::ones(input_shape, &device);
let output = model.forward(input);
assert_eq!(expected_shape, output.shape());
let output = model.forward(input, 1.0);
assert_eq!(expected_shape, output.0.shape());
assert_eq!(Shape::from([1]), output.1.shape());
}
#[test]
@@ -1079,8 +1080,9 @@ mod tests {
let input_shape = Shape::from([3, 4, 5]);
let expected_shape = Shape::from([3, 4, 5, 1]);
let input = Tensor::ones(input_shape, &device);
let output = model.forward(input);
assert_eq!(expected_shape, output.shape());
let output = model.forward(input, 1.0);
assert_eq!(expected_shape, output.0.shape());
assert_eq!(Shape::from([1]), output.1.shape());
}
#[test]

View File

@@ -11,9 +11,10 @@ class Model(nn.Module):
super(Model, self).__init__()
self.axis = 3
def forward(self, x):
def forward(self, x, scalar):
x = torch.unsqueeze(x, self.axis)
return x
y = torch.unsqueeze(torch.tensor(scalar), 0)
return x, y
def main():
@@ -27,21 +28,21 @@ def main():
model.eval()
device = torch.device("cpu")
test_input = torch.randn(3, 4, 5, device=device)
test_input = (torch.randn(3, 4, 5, device=device),1.0)
model = Model()
output = model.forward(test_input)
output = model.forward(*test_input)
torch.onnx.export(model, (test_input), "unsqueeze_opset16.onnx", verbose=False, opset_version=16)
torch.onnx.export(model, (test_input), "unsqueeze_opset11.onnx", verbose=False, opset_version=11)
torch.onnx.export(model, test_input, "unsqueeze_opset16.onnx", verbose=False, opset_version=16)
torch.onnx.export(model, test_input, "unsqueeze_opset11.onnx", verbose=False, opset_version=11)
print(f"Finished exporting model")
# Output some test data for use in the test
print(f"Test input data of ones: {test_input}")
print(f"Test input data shape of ones: {test_input.shape}")
print(f"Test input data shape of ones: {test_input[0].shape}")
# output = model.forward(test_input)
print(f"Test output data shape: {output.shape}")
print(f"Test output data shape: {output[0].shape}")
print(f"Test output: {output}")

View File

@@ -414,6 +414,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
// Extend with phantom data to avoid unused generic type.
body.extend(quote! {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
});
quote! {
@@ -447,6 +448,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
Self {
#(#fields,)*
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
}

View File

@@ -127,6 +127,7 @@ mod tests {
pub struct Model <B: Backend> {
avg_pool2d: AvgPool2d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -141,6 +142,7 @@ mod tests {
Self {
avg_pool2d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -240,6 +240,7 @@ pub(crate) mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -247,6 +248,7 @@ pub(crate) mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
@@ -294,6 +296,7 @@ pub(crate) mod tests {
pub struct Model <B: Backend> {
conv2d: Conv2d<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -310,6 +313,7 @@ pub(crate) mod tests {
Self {
conv2d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
@@ -370,6 +374,7 @@ pub(crate) mod tests {
pub struct Model <B: Backend> {
conv2d: Conv2d<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -386,6 +391,7 @@ pub(crate) mod tests {
Self {
conv2d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -188,6 +188,7 @@ mod tests {
pub struct Model <B: Backend> {
norm: BatchNorm<B, 2>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -201,6 +202,7 @@ mod tests {
Self {
norm,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -354,6 +354,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -361,6 +362,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

View File

@@ -78,6 +78,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -85,6 +86,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
@@ -121,6 +123,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -128,6 +131,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
@@ -164,6 +168,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -171,6 +176,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -82,6 +82,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -89,6 +90,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

View File

@@ -164,6 +164,7 @@ mod tests {
pub struct Model <B: Backend> {
conv1d: Conv1d<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -180,6 +181,7 @@ mod tests {
Self {
conv1d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -163,6 +163,7 @@ mod tests {
pub struct Model <B: Backend> {
conv2d: Conv2d<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -179,6 +180,7 @@ mod tests {
Self {
conv2d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -160,6 +160,7 @@ mod tests {
pub struct Model <B: Backend> {
conv_transpose_2d: ConvTranspose2d<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -176,6 +177,7 @@ mod tests {
Self {
conv_transpose_2d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -110,6 +110,7 @@ mod tests {
pub struct Model <B: Backend> {
dropout: Dropout,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
@@ -122,6 +123,7 @@ mod tests {
Self {
dropout,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -82,6 +82,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -89,6 +90,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

View File

@@ -139,6 +139,7 @@ mod tests {
pub struct Model <B: Backend> {
global_avg_pool1: AdaptiveAvgPool2d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -150,6 +151,7 @@ mod tests {
Self {
global_avg_pool1,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
@@ -188,6 +190,7 @@ mod tests {
pub struct Model <B: Backend> {
global_avg_pool1: AdaptiveAvgPool1d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -199,6 +202,7 @@ mod tests {
Self {
global_avg_pool1,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -149,6 +149,7 @@ mod tests {
pub struct Model <B: Backend> {
norm: LayerNorm<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -161,6 +162,7 @@ mod tests {
Self {
norm,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -143,6 +143,7 @@ mod tests {
pub struct Model <B: Backend> {
linear: Linear<B>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -155,6 +156,7 @@ mod tests {
Self {
linear,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -112,6 +112,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -119,6 +120,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
@@ -169,6 +171,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -176,6 +179,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

View File

@@ -126,6 +126,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -133,6 +134,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
@@ -176,6 +178,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -183,6 +186,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
@@ -226,6 +230,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -233,6 +238,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}

View File

@@ -127,6 +127,7 @@ mod tests {
pub struct Model <B: Backend> {
max_pool2d: MaxPool2d,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -141,6 +142,7 @@ mod tests {
Self {
max_pool2d,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -67,6 +67,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -74,6 +75,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]

View File

@@ -1,12 +1,12 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use crate::burn::{BurnImports, Scope, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;
#[derive(Debug, Clone, new)]
pub struct UnsqueezeNode {
pub input: TensorType,
pub input: Type,
pub output: TensorType,
pub axes: Vec<i64>,
}
@@ -17,23 +17,43 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnsqueezeNode {
}
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
vec![self.input.clone()]
}
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let shape_values = &self.axes.to_tokens();
let new_dims = self.output.dim.to_tokens();
quote! {
let #output: Tensor<B, #new_dims> = #input.unsqueeze_dims(&#shape_values);
match &self.input {
Type::Tensor(tensor) => {
let input = scope.tensor_use_owned(tensor, node_position);
quote! {
let #output: Tensor<B, #new_dims> = #input.unsqueeze_dims(&#shape_values);
}
}
Type::Scalar(scalar) => {
let input = &scalar.name;
quote! {
let #output = Tensor::<B, #new_dims>::from_data([#input.elem()], &self.device).unsqueeze();
}
}
_ => panic!("Unsupported input type"),
}
}
fn into_node(self) -> Node<PS> {
Node::Unsqueeze(self)
}
fn register_imports(&self, imports: &mut BurnImports) {
match &self.input {
Type::Scalar(_) => {
imports.register("burn::tensor::ElementConversion");
}
_ => {}
}
}
}
#[cfg(test)]
@@ -44,7 +64,7 @@ mod tests {
use crate::burn::{
graph::BurnGraph,
node::{test::assert_tokens, unsqueeze::UnsqueezeNode},
TensorType,
TensorType, Type,
};
#[test]
@@ -52,7 +72,7 @@ mod tests {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(UnsqueezeNode::new(
TensorType::new_float("tensor1", 3),
Type::Tensor(TensorType::new_float("tensor1", 3)),
TensorType::new_float("tensor2", 5),
[0, 4].into(),
));
@@ -68,6 +88,7 @@ mod tests {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
@@ -75,6 +96,7 @@ mod tests {
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
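The test above exercises only the tensor path. A scalar-input registration would look roughly like this (assuming `ScalarType` and `ScalarKind` from `crate::burn`; names are illustrative):

graph.register(UnsqueezeNode::new(
    Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)),
    TensorType::new_float("tensor2", 1),
    [0].into(),
));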

View File

@@ -269,22 +269,27 @@ fn unsqueeze_update_output(node: &mut Node) {
node.attrs.get("axes").cloned().map(|v| v.into_i64s())
};
// need output way up here to avoid borrowing issues
let input = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
if axes.is_none() {
return;
}
let input_dim = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.dim,
ArgType::Scalar(_) => 0, // treat scalar as 0-dim tensor
_ => panic!("Unsqueeze: invalid input type"),
};
let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Unsqueeze: invalid output types"),
let output_elem = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.elem_type.clone(),
ArgType::Scalar(elem_type) => elem_type.clone(),
_ => panic!("Unsqueeze: invalid output type"),
};
if let Some(axes) = axes {
node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input.dim + axes.len(),
shape: None, // shape is calculated at runtime
..output
dim: input_dim + axes.len(),
shape: None, // shape is tracked and calculated at runtime
elem_type: output_elem,
});
}
}
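The rank rule implemented here is deliberately simple: each entry in `axes` inserts one new dimension, and a scalar input counts as a rank-0 tensor. As a standalone sketch:

// Each axis entry inserts one dimension; scalars contribute rank 0.
fn unsqueezed_rank(input_rank: usize, axes: &[i64]) -> usize {
    input_rank + axes.len()
}

fn main() {
    assert_eq!(unsqueezed_rank(3, &[0, 4]), 5); // tensor case from the tests
    assert_eq!(unsqueezed_rank(0, &[0]), 1);    // scalar case added by this commit
}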

View File

@@ -501,7 +501,7 @@ impl OnnxGraph {
}
fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
let input = node.inputs.first().unwrap().to_tensor_type();
let input = node.inputs.first().unwrap().to_type();
let output = node.outputs.first().unwrap().to_tensor_type();
let dims = unsqueeze_config(&node);