mirror of https://github.com/tracel-ai/burn.git
ONNX support for scalar unsqueeze (#1690)
* Revert1c639c8393
1c639c8393
?diff=unified&w=0 * Refactor by @laggui * Refactor unsqueeze * Add support for scalar unsqueeze * Removed dead comment
This commit is contained in:
parent
599a20d586
commit
67ec06d5d8
|
@ -208,6 +208,67 @@ impl<B: AutodiffBackend> AutodiffModule<B> for PhantomData<B> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Container to satisfy the Module trait for types that are not modules.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Ignored<T>(pub T);
|
||||
|
||||
impl<B, T> Module<B> for Ignored<T>
|
||||
where
|
||||
B: Backend,
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
type Record = ConstantRecord;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
ConstantRecord::new()
|
||||
}
|
||||
|
||||
fn to_device(self, _: &<B as Backend>::Device) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn fork(self, _: &<B as Backend>::Device) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
|
||||
devices
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend, T> AutodiffModule<B> for Ignored<T>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
T: Sync + Send + core::fmt::Debug + Clone,
|
||||
{
|
||||
type InnerModule = Ignored<T>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// Implement deref for Ignored
|
||||
impl<T> core::ops::Deref for Ignored<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests {
|
||||
use core::marker::PhantomData;
|
||||
|
|
|
@ -1068,8 +1068,9 @@ mod tests {
|
|||
let input_shape = Shape::from([3, 4, 5]);
|
||||
let expected_shape = Shape::from([3, 4, 5, 1]);
|
||||
let input = Tensor::ones(input_shape, &device);
|
||||
let output = model.forward(input);
|
||||
assert_eq!(expected_shape, output.shape());
|
||||
let output = model.forward(input, 1.0);
|
||||
assert_eq!(expected_shape, output.0.shape());
|
||||
assert_eq!(Shape::from([1]), output.1.shape());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1079,8 +1080,9 @@ mod tests {
|
|||
let input_shape = Shape::from([3, 4, 5]);
|
||||
let expected_shape = Shape::from([3, 4, 5, 1]);
|
||||
let input = Tensor::ones(input_shape, &device);
|
||||
let output = model.forward(input);
|
||||
assert_eq!(expected_shape, output.shape());
|
||||
let output = model.forward(input, 1.0);
|
||||
assert_eq!(expected_shape, output.0.shape());
|
||||
assert_eq!(Shape::from([1]), output.1.shape());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -11,9 +11,10 @@ class Model(nn.Module):
|
|||
super(Model, self).__init__()
|
||||
self.axis = 3
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, scalar):
|
||||
x = torch.unsqueeze(x, self.axis)
|
||||
return x
|
||||
y = torch.unsqueeze(torch.tensor(scalar), 0)
|
||||
return x, y
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -27,21 +28,21 @@ def main():
|
|||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
test_input = torch.randn(3, 4, 5, device=device)
|
||||
test_input = (torch.randn(3, 4, 5, device=device),1.0)
|
||||
model = Model()
|
||||
|
||||
output = model.forward(test_input)
|
||||
output = model.forward(*test_input)
|
||||
|
||||
torch.onnx.export(model, (test_input), "unsqueeze_opset16.onnx", verbose=False, opset_version=16)
|
||||
torch.onnx.export(model, (test_input), "unsqueeze_opset11.onnx", verbose=False, opset_version=11)
|
||||
torch.onnx.export(model, test_input, "unsqueeze_opset16.onnx", verbose=False, opset_version=16)
|
||||
torch.onnx.export(model, test_input, "unsqueeze_opset11.onnx", verbose=False, opset_version=11)
|
||||
|
||||
print(f"Finished exporting model")
|
||||
|
||||
# Output some test data for use in the test
|
||||
print(f"Test input data of ones: {test_input}")
|
||||
print(f"Test input data shape of ones: {test_input.shape}")
|
||||
print(f"Test input data shape of ones: {test_input[0].shape}")
|
||||
# output = model.forward(test_input)
|
||||
print(f"Test output data shape: {output.shape}")
|
||||
print(f"Test output data shape: {output[0].shape}")
|
||||
|
||||
print(f"Test output: {output}")
|
||||
|
||||
|
|
|
@ -414,6 +414,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
// Extend with phantom data to avoid unused generic type.
|
||||
body.extend(quote! {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
});
|
||||
|
||||
quote! {
|
||||
|
@ -447,6 +448,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
Self {
|
||||
#(#fields,)*
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -127,6 +127,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
avg_pool2d: AvgPool2d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -141,6 +142,7 @@ mod tests {
|
|||
Self {
|
||||
avg_pool2d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -240,6 +240,7 @@ pub(crate) mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -247,6 +248,7 @@ pub(crate) mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -294,6 +296,7 @@ pub(crate) mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
conv2d: Conv2d<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -310,6 +313,7 @@ pub(crate) mod tests {
|
|||
Self {
|
||||
conv2d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
@ -370,6 +374,7 @@ pub(crate) mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
conv2d: Conv2d<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -386,6 +391,7 @@ pub(crate) mod tests {
|
|||
Self {
|
||||
conv2d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -188,6 +188,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
norm: BatchNorm<B, 2>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -201,6 +202,7 @@ mod tests {
|
|||
Self {
|
||||
norm,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -354,6 +354,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -361,6 +362,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -78,6 +78,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -85,6 +86,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
@ -121,6 +123,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -128,6 +131,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
@ -164,6 +168,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -171,6 +176,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -82,6 +82,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -89,6 +90,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -164,6 +164,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
conv1d: Conv1d<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -180,6 +181,7 @@ mod tests {
|
|||
Self {
|
||||
conv1d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -163,6 +163,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
conv2d: Conv2d<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -179,6 +180,7 @@ mod tests {
|
|||
Self {
|
||||
conv2d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -160,6 +160,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
conv_transpose_2d: ConvTranspose2d<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -176,6 +177,7 @@ mod tests {
|
|||
Self {
|
||||
conv_transpose_2d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -110,6 +110,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
dropout: Dropout,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
|
||||
}
|
||||
|
||||
|
@ -122,6 +123,7 @@ mod tests {
|
|||
Self {
|
||||
dropout,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -82,6 +82,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -89,6 +90,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -139,6 +139,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
global_avg_pool1: AdaptiveAvgPool2d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -150,6 +151,7 @@ mod tests {
|
|||
Self {
|
||||
global_avg_pool1,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
@ -188,6 +190,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
global_avg_pool1: AdaptiveAvgPool1d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -199,6 +202,7 @@ mod tests {
|
|||
Self {
|
||||
global_avg_pool1,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -149,6 +149,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
norm: LayerNorm<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -161,6 +162,7 @@ mod tests {
|
|||
Self {
|
||||
norm,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -143,6 +143,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
linear: Linear<B>,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -155,6 +156,7 @@ mod tests {
|
|||
Self {
|
||||
linear,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -112,6 +112,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -119,6 +120,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -169,6 +171,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -176,6 +179,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -126,6 +126,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -133,6 +134,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -176,6 +178,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -183,6 +186,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -226,6 +230,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -233,6 +238,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -127,6 +127,7 @@ mod tests {
|
|||
pub struct Model <B: Backend> {
|
||||
max_pool2d: MaxPool2d,
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -141,6 +142,7 @@ mod tests {
|
|||
Self {
|
||||
max_pool2d,
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -67,6 +67,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -74,6 +75,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{Scope, TensorType, ToTokens, Type};
|
||||
use crate::burn::{BurnImports, Scope, TensorType, ToTokens, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub struct UnsqueezeNode {
|
||||
pub input: TensorType,
|
||||
pub input: Type,
|
||||
pub output: TensorType,
|
||||
pub axes: Vec<i64>,
|
||||
}
|
||||
|
@ -17,23 +17,43 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnsqueezeNode {
|
|||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
vec![self.input.clone()]
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output = &self.output.name;
|
||||
let shape_values = &self.axes.to_tokens();
|
||||
let new_dims = self.output.dim.to_tokens();
|
||||
|
||||
quote! {
|
||||
let #output: Tensor<B, #new_dims> = #input.unsqueeze_dims(&#shape_values);
|
||||
match &self.input {
|
||||
Type::Tensor(tensor) => {
|
||||
let input = scope.tensor_use_owned(tensor, node_position);
|
||||
quote! {
|
||||
let #output: Tensor<B, #new_dims> = #input.unsqueeze_dims(&#shape_values);
|
||||
}
|
||||
}
|
||||
Type::Scalar(scalar) => {
|
||||
let input = &scalar.name;
|
||||
quote! {
|
||||
let #output = Tensor::<B, #new_dims>::from_data([#input.elem()], &self.device).unsqueeze();
|
||||
}
|
||||
}
|
||||
_ => panic!("Unsupported input type"),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::Unsqueeze(self)
|
||||
}
|
||||
|
||||
fn register_imports(&self, imports: &mut BurnImports) {
|
||||
match &self.input {
|
||||
Type::Scalar(_) => {
|
||||
imports.register("burn::tensor::ElementConversion");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -44,7 +64,7 @@ mod tests {
|
|||
use crate::burn::{
|
||||
graph::BurnGraph,
|
||||
node::{test::assert_tokens, unsqueeze::UnsqueezeNode},
|
||||
TensorType,
|
||||
TensorType, Type,
|
||||
};
|
||||
|
||||
#[test]
|
||||
|
@ -52,7 +72,7 @@ mod tests {
|
|||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(UnsqueezeNode::new(
|
||||
TensorType::new_float("tensor1", 3),
|
||||
Type::Tensor(TensorType::new_float("tensor1", 3)),
|
||||
TensorType::new_float("tensor2", 5),
|
||||
[0, 4].into(),
|
||||
));
|
||||
|
@ -68,6 +88,7 @@ mod tests {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
|
@ -75,6 +96,7 @@ mod tests {
|
|||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
|
|
|
@ -269,22 +269,27 @@ fn unsqueeze_update_output(node: &mut Node) {
|
|||
node.attrs.get("axes").cloned().map(|v| v.into_i64s())
|
||||
};
|
||||
|
||||
// need output way up here to avoid borrowing issues
|
||||
let input = match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Unsqueeze: invalid output types"),
|
||||
if axes.is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
let input_dim = match &node.inputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.dim,
|
||||
ArgType::Scalar(_) => 0, // treat scalar as 0-dim tensor
|
||||
_ => panic!("Unsqueeze: invalid input type"),
|
||||
};
|
||||
|
||||
let output = match &node.outputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.clone(),
|
||||
_ => panic!("Unsqueeze: invalid output types"),
|
||||
let output_elem = match &node.outputs[0].ty {
|
||||
ArgType::Tensor(tensor) => tensor.elem_type.clone(),
|
||||
ArgType::Scalar(elem_type) => elem_type.clone(),
|
||||
_ => panic!("Unsqueeze: invalid output type"),
|
||||
};
|
||||
|
||||
if let Some(axes) = axes {
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorType {
|
||||
dim: input.dim + axes.len(),
|
||||
shape: None, // shape is calculated at runtime
|
||||
..output
|
||||
dim: input_dim + axes.len(),
|
||||
shape: None, // shape is tracked and calculated at runtime
|
||||
elem_type: output_elem,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -501,7 +501,7 @@ impl OnnxGraph {
|
|||
}
|
||||
|
||||
fn unsqueeze_conversion(node: Node) -> UnsqueezeNode {
|
||||
let input = node.inputs.first().unwrap().to_tensor_type();
|
||||
let input = node.inputs.first().unwrap().to_type();
|
||||
let output = node.outputs.first().unwrap().to_tensor_type();
|
||||
let dims = unsqueeze_config(&node);
|
||||
|
||||
|
|
Loading…
Reference in New Issue