mirror of https://github.com/tracel-ai/burn.git
Full support for ONNX scalar operators and Constants (#578)
This commit is contained in:
parent
ca9a8808d9
commit
1554a3c898
|
@ -11,6 +11,7 @@ members = [
|
|||
"burn-dataset",
|
||||
"burn-derive",
|
||||
"burn-import",
|
||||
"burn-import/onnx-tests",
|
||||
"burn-ndarray",
|
||||
"burn-no-std-tests",
|
||||
"burn-tch",
|
||||
|
@ -29,6 +30,7 @@ dashmap = "5.4.0"
|
|||
dirs = "5.0.1"
|
||||
fake = "2.6.1"
|
||||
flate2 = "1.0.26"
|
||||
float-cmp = "0.9.0"
|
||||
gix-tempfile = {version = "7.0.0", features = ["signals"]}
|
||||
hashbrown = "0.14.0"
|
||||
indicatif = "0.17.5"
|
||||
|
|
|
@ -23,6 +23,12 @@ impl<T: Clone> Param<T> {
|
|||
pub fn val(&self) -> T {
|
||||
self.value.clone()
|
||||
}
|
||||
|
||||
/// Execute the given function on the inner value.
|
||||
pub fn map<F: FnOnce(T) -> T>(mut self, func: F) -> Self {
|
||||
self.value = func(self.value);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> core::ops::Deref for Param<T> {
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use crate::{
|
||||
self as burn,
|
||||
module::{ADModule, Module, ModuleMapper, ModuleVisitor},
|
||||
|
@ -135,12 +137,10 @@ impl<const D: usize, B: Backend> Module<B> for Tensor<B, D> {
|
|||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
// Treat as a constant and do not record
|
||||
ConstantRecord::new()
|
||||
ConstantRecord
|
||||
}
|
||||
|
||||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
// Treat as a constant and do not load
|
||||
self
|
||||
}
|
||||
}
|
||||
|
@ -153,15 +153,49 @@ impl<const D: usize, B: ADBackend> ADModule<B> for Tensor<B, D> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Module<B> for PhantomData<B> {
|
||||
type Record = ConstantRecord;
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn load_record(self, _record: Self::Record) -> Self {
|
||||
self
|
||||
}
|
||||
|
||||
fn into_record(self) -> Self::Record {
|
||||
ConstantRecord::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: ADBackend> ADModule<B> for PhantomData<B> {
|
||||
type InnerModule = PhantomData<B::InnerBackend>;
|
||||
|
||||
fn valid(&self) -> Self::InnerModule {
|
||||
PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "std"))]
|
||||
mod tests {
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::Tensor;
|
||||
|
||||
use crate::module::Module;
|
||||
use crate::TestBackend;
|
||||
use crate::{
|
||||
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
|
||||
TestADBackend,
|
||||
};
|
||||
use burn::module::Module;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
#[test]
|
||||
fn tensor_load_record_setting() {
|
||||
|
@ -185,4 +219,16 @@ mod tests {
|
|||
assert!(!no_grad_is_require_grad);
|
||||
assert!(!with_default_is_require_grad);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_module_with_phantom() {
|
||||
#[derive(Module, Debug, new)]
|
||||
struct EmptyModule<B: Backend> {
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
let _module = EmptyModule::<TestBackend>::new();
|
||||
|
||||
assert_eq!(core::mem::size_of::<EmptyModule<TestBackend>>(), 0);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,3 +15,5 @@ pub use settings::*;
|
|||
mod file;
|
||||
#[cfg(feature = "std")]
|
||||
pub use file::*;
|
||||
|
||||
pub use primitive::ParamSerde;
|
||||
|
|
|
@ -32,7 +32,7 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
|
|||
- [ ] BitwiseOr
|
||||
- [ ] BitwiseXor
|
||||
- [ ] BlackmanWindow
|
||||
- [ ] Cast
|
||||
- [x] Cast
|
||||
- [ ] CastLike
|
||||
- [ ] Ceil
|
||||
- [ ] Celu
|
||||
|
@ -40,9 +40,9 @@ List taken from [here](https://github.com/onnx/onnx/blob/main/docs/Operators.md)
|
|||
- [ ] Clip
|
||||
- [ ] Col
|
||||
- [ ] Compress
|
||||
- [ ] Concat
|
||||
- [x] Concat
|
||||
- [ ] ConcatFromSequence
|
||||
- [ ] Constant
|
||||
- [x] Constant
|
||||
- [ ] ConstantOfShape
|
||||
- [ ] Conv
|
||||
- [ ] Conv1d
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "onnx-tests"
|
||||
version = "0.9.0"
|
||||
edition = "2021"
|
||||
|
||||
[dev-dependencies]
|
||||
burn = { path = "../../burn" }
|
||||
burn-ndarray = { path = "../../burn-ndarray" }
|
||||
serde = { workspace = true }
|
||||
float-cmp = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
burn-import = { path = "../" }
|
|
@ -0,0 +1,32 @@
|
|||
# ONNX Tests
|
||||
|
||||
This crate contains ONNX models that are utilized in testing the conversion of ONNX to Burn source
|
||||
code through the `burn-import` crate. The tests are designed as end-to-end tests, ensuring that ONNX
|
||||
models are accurately converted into Burn source code. Of utmost importance is verifying that the
|
||||
converted Burn source code compiles without errors and produces the same output as the original ONNX
|
||||
model.
|
||||
|
||||
Here is the directory structure of this crate:
|
||||
|
||||
- `tests/<model>`: This directory contains the ONNX model and the Python script to generate it.
|
||||
- `tests/<model>/<model>.onnx`: The ONNX model is generated by the script.
|
||||
- `tests/<model>/<model>.py`: This is the Python script responsible for generating the ONNX model
|
||||
using PyTorch.
|
||||
- `tests/onnx_tests.rs`: This is the main test file, where all the tests are contained.
|
||||
- `build.rs`: This build script generates the ONNX models and is executed by `cargo test` before
|
||||
running the actual tests.
|
||||
|
||||
## Adding new tests
|
||||
|
||||
Here are the steps to add a new test:
|
||||
|
||||
1. Add your Python script to the `tests/<model>` directory. Refer to existing scripts for examples.
|
||||
2. Run your Python script to generate the ONNX model and inspect the output of the model with the
|
||||
test data. Use the inputs and outputs in your test.
|
||||
3. Make sure the ONNX output contains the desired operators by verifying with the
|
||||
[Netron](https://github.com/lutzroeder/netron) app. Sometimes PyTorch will optimize the model and
|
||||
remove operators that are not necessary for the model to run. If this happens, you can disable
|
||||
optimization by setting `torch.onnx.export(..., do_constant_folding=False)`.
|
||||
4. Add an entry to the `build.rs` file to account for the generation of the new ONNX model.
|
||||
5. Include a test in `tests/onnx_tests.rs` to test the new ONNX model.
|
||||
6. Run `cargo test` to ensure your test passes.
|
|
@ -0,0 +1,19 @@
|
|||
use burn_import::onnx::ModelGen;
|
||||
|
||||
fn main() {
|
||||
// Re-run this build script if the onnx-tests directory changes.
|
||||
println!("cargo:rerun-if-changed=tests");
|
||||
|
||||
// Add onnx models.
|
||||
ModelGen::new()
|
||||
.input("tests/add/add.onnx")
|
||||
.input("tests/sub/sub.onnx")
|
||||
.input("tests/mul/mul.onnx")
|
||||
.input("tests/div/div.onnx")
|
||||
.input("tests/concat/concat.onnx")
|
||||
.input("tests/conv2d/conv2d.onnx")
|
||||
.out_dir("model/")
|
||||
.run_from_script();
|
||||
|
||||
// panic!("Purposefully failing build to output logs.");
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
|
Binary file not shown.
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/add/add.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
# Declare a constant float tensor with ones
|
||||
self.a = torch.ones(1, 1, 1, 4)
|
||||
|
||||
# Declare a scalar
|
||||
self.b = 5.0
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, k):
|
||||
|
||||
# Add a tensor input and a constant tensor
|
||||
x = x + self.a
|
||||
|
||||
# Add a scalar constant and a scalar input
|
||||
d = self.b + k
|
||||
|
||||
# Add a tensor and a scalar
|
||||
x = x + d
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "add.onnx"
|
||||
dummy_input = torch.randn(1, 2, 3, 4, device=device)
|
||||
|
||||
scalar = 2.0
|
||||
|
||||
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
|
||||
|
||||
print("Test input data: {}, {}".format(test_input, scalar))
|
||||
output = model.forward(test_input, scalar)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,4 +1,4 @@
|
|||
pytorch2.0.1:£
|
||||
pytorch2.0.1:¡
|
||||
P
|
||||
onnx::Concat_0
|
||||
onnx::Concat_0/Concat_output_0/Concat"Concat*
|
||||
|
@ -9,16 +9,16 @@ P
|
|||
/Concat_output_0
|
||||
/Concat_output_0
|
||||
/Concat_output_02 /Concat_1"Concat*
|
||||
axis torch_jitZ)
|
||||
onnx::Concat_0
|
||||
|
||||
axis torch_jitZ(
|
||||
onnx::Concat_0
|
||||
|
||||
|
||||
€
|
||||
|
||||
b
|
||||
2
|
||||
|
||||
|
||||
|
||||
b
|
||||
2
|
||||
|
||||
|
||||
€
|
||||
|
||||
B
|
||||
|
||||
|
||||
B
|
19
burn-import/tests/data/concat/concat.py → burn-import/onnx-tests/tests/concat/concat.py
Normal file → Executable file
19
burn-import/tests/data/concat/concat.py → burn-import/onnx-tests/tests/concat/concat.py
Normal file → Executable file
|
@ -1,9 +1,9 @@
|
|||
# used to generate model: burn-import/tests/data/conv2d/conv2d.onnx
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/concat/concat.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import onnx
|
||||
from onnxoptimizer import optimize
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -24,9 +24,20 @@ def main():
|
|||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "concat.onnx"
|
||||
dummy_input = torch.randn(1,256,13,13, device=device)
|
||||
dummy_input = torch.randn(1,2,3,5, device=device)
|
||||
torch.onnx.export(model, dummy_input, onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
test_input = torch.randn(1,2,3,5, device=device)
|
||||
print("Test input data shape: {}".format(test_input.shape))
|
||||
output = model.forward(test_input)
|
||||
|
||||
print("Test output data shape: {}".format(output.shape))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,43 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/conv2d/conv2d.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.conv1 = nn.Conv2d(4, 6, (3, 5), groups = 2, stride=(2, 1), padding=(4, 2), dilation=(3, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
return x
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
|
||||
file_name = "conv2d.onnx"
|
||||
test_input = torch.ones(2, 4, 10, 15, device=device)
|
||||
torch.onnx.export(model, test_input, file_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(file_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
print("Test input data shape of ones: {}".format(test_input.shape))
|
||||
output = model.forward(test_input)
|
||||
print("Test output data shape: {}".format(output.shape))
|
||||
|
||||
sum = output.sum().item()
|
||||
|
||||
print("Test output sum: {}".format(sum))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/add/add.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, k, m):
|
||||
|
||||
a = k / m
|
||||
|
||||
x = x / a
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "div.onnx"
|
||||
dummy_input = torch.randn(1, 2, 3, 4, device=device)
|
||||
|
||||
scalar1, scalar2 = 9.0, 3.0
|
||||
|
||||
torch.onnx.export(model, (dummy_input, scalar1, scalar2), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
test_input = torch.tensor([[[[3.0, 6.0, 6.0, 9.0]]]])
|
||||
|
||||
print("Test input data: {}, {}, {}".format(test_input, scalar1, scalar2))
|
||||
output = model.forward(test_input, scalar1, scalar2)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Binary file not shown.
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/add/add.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
# Declare a constant float tensor
|
||||
self.a = torch.full((1, 1, 1, 4), 3.0)
|
||||
|
||||
# Declare a scalar
|
||||
self.b = 7.0
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, k):
|
||||
|
||||
# Multiply the input by the constant tensor
|
||||
x = x * self.a
|
||||
|
||||
# Multiply the input scalar by the constant scalar
|
||||
d = k * self.b
|
||||
|
||||
# Multiply the result of the previous multiplications
|
||||
x = x * d
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "mul.onnx"
|
||||
dummy_input = torch.randn(1, 2, 3, 4, device=device)
|
||||
|
||||
scalar = 6.0
|
||||
|
||||
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
|
||||
|
||||
print("Test input data: {}, {}".format(test_input, scalar))
|
||||
output = model.forward(test_input, scalar)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,128 @@
|
|||
pub mod add {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/add.rs"));
|
||||
}
|
||||
|
||||
pub mod sub {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/sub.rs"));
|
||||
}
|
||||
|
||||
pub mod mul {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/mul.rs"));
|
||||
}
|
||||
|
||||
pub mod div {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/div.rs"));
|
||||
}
|
||||
|
||||
pub mod concat {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/concat.rs"));
|
||||
}
|
||||
|
||||
pub mod conv2d {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/conv2d.rs"));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
use burn::tensor::{Data, Shape, Tensor};
|
||||
|
||||
use float_cmp::ApproxEq;
|
||||
|
||||
type Backend = burn_ndarray::NdArrayBackend<f32>;
|
||||
|
||||
#[test]
|
||||
fn add_scalar_to_tensor_and_tensor_to_tensor() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
let model: add::Model<Backend> = add::Model::default();
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
|
||||
let scalar = 2f64;
|
||||
let output = model.forward(input, scalar);
|
||||
let expected = Data::from([[[[9., 10., 11., 12.]]]]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_scalar_from_tensor_and_tensor_from_tensor() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
let model: sub::Model<Backend> = sub::Model::default();
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
|
||||
let scalar = 3.0f64;
|
||||
let output = model.forward(input, scalar);
|
||||
let expected = Data::from([[[[6., 7., 8., 9.]]]]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_scalar_with_tensor_and_tensor_with_tensor() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
let model: mul::Model<Backend> = mul::Model::default();
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats([[[[1., 2., 3., 4.]]]]);
|
||||
let scalar = 6.0f64;
|
||||
let output = model.forward(input, scalar);
|
||||
let expected = Data::from([[[[126., 252., 378., 504.]]]]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn div_tensor_by_scalar_and_tensor_by_tensor() {
|
||||
// Initialize the model without weights (because the exported file does not contain them)
|
||||
let model: div::Model<Backend> = div::Model::new();
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::from_floats([[[[3., 6., 6., 9.]]]]);
|
||||
let scalar1 = 9.0f64;
|
||||
let scalar2 = 3.0f64;
|
||||
let output = model.forward(input, scalar1, scalar2);
|
||||
let expected = Data::from([[[[1., 2., 2., 3.]]]]);
|
||||
|
||||
assert_eq!(output.to_data(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concat_tensors() {
|
||||
// Initialize the model
|
||||
let model: concat::Model<Backend> = concat::Model::new();
|
||||
|
||||
// Run the model
|
||||
let input = Tensor::<Backend, 4>::zeros([1, 2, 3, 5]);
|
||||
|
||||
let output = model.forward(input);
|
||||
|
||||
let expected = Shape::from([1, 18, 3, 5]);
|
||||
|
||||
assert_eq!(output.shape(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv2d() {
|
||||
// Initialize the model with weights (loaded from the exported file)
|
||||
let model: conv2d::Model<Backend> = conv2d::Model::default();
|
||||
|
||||
// Run the model with ones as input for easier testing
|
||||
let input = Tensor::<Backend, 4>::ones([2, 4, 10, 15]);
|
||||
|
||||
let output = model.forward(input);
|
||||
|
||||
let expected_shape = Shape::from([2, 6, 6, 15]);
|
||||
assert_eq!(output.shape().clone(), expected_shape);
|
||||
|
||||
// We are using the sum of the output tensor to test the correctness of the conv2d node
|
||||
// because the output tensor is too large to compare with the expected tensor.
|
||||
let output_sum = output.sum().into_scalar();
|
||||
|
||||
let expected_sum = 24.004_995; // from pytorch
|
||||
|
||||
assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2)));
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# used to generate model: onnx-tests/tests/add/add.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
# Declare a constant float tensor with ones
|
||||
self.a = torch.ones(1, 1, 1, 4)
|
||||
|
||||
# Declare a scalar
|
||||
self.b = 9.0
|
||||
super(Model, self).__init__()
|
||||
|
||||
def forward(self, x, k):
|
||||
|
||||
# Subtract a constant tensor from a tensor input
|
||||
x = x - self.a
|
||||
|
||||
# Subtract a scalar constant from a scalar input
|
||||
d = k - self.b
|
||||
|
||||
# Sutract a scalar from a tensor
|
||||
x = x - d
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
onnx_name = "sub.onnx"
|
||||
dummy_input = torch.randn(1, 2, 3, 4, device=device)
|
||||
|
||||
scalar = 3.0
|
||||
|
||||
torch.onnx.export(model, (dummy_input, scalar), onnx_name,
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
print("Finished exporting model to {}".format(onnx_name))
|
||||
|
||||
# Output some test data for use in the test
|
||||
test_input = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])
|
||||
|
||||
print("Test input data: {}, {}".format(test_input, scalar))
|
||||
output = model.forward(test_input, scalar)
|
||||
print("Test output data: {}".format(output))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -9,7 +9,7 @@ use burn::record::{
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
use serde::{ser::SerializeMap, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::{collections::HashMap, path::PathBuf};
|
||||
|
||||
/// Burn graph intermediate representation of modules and tensor operations.
|
||||
#[derive(Default, Debug)]
|
||||
|
@ -21,6 +21,8 @@ pub struct BurnGraph<PS: PrecisionSettings> {
|
|||
default: Option<TokenStream>,
|
||||
blank_spaces: bool,
|
||||
gen_new_fn: bool,
|
||||
graph_input_types: Vec<Type>,
|
||||
graph_output_types: Vec<Type>,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> BurnGraph<PS> {
|
||||
|
@ -163,20 +165,22 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
fn build_scope(&mut self) {
|
||||
log::debug!("Building the scope nodes len => '{}'", self.nodes.len());
|
||||
|
||||
let input = self.nodes.first().unwrap();
|
||||
|
||||
fn to_tensor(ty: Type<'_>) -> Option<&TensorType> {
|
||||
fn to_tensor(ty: Type) -> Option<TensorType> {
|
||||
match ty {
|
||||
Type::Tensor(tensor) => Some(tensor),
|
||||
Type::Tensor(tensor) => Some(tensor.clone()),
|
||||
Type::Scalar(_) => None,
|
||||
Type::Other(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
input
|
||||
.input_types()
|
||||
// Register graph tensor input with 0 as node position
|
||||
self.graph_input_types
|
||||
.clone()
|
||||
.into_iter()
|
||||
.flat_map(to_tensor)
|
||||
.for_each(|tensor| self.scope.tensor_register_variable(tensor, 0));
|
||||
.for_each(|tensor| {
|
||||
self.scope.tensor_register_variable(&tensor, 0);
|
||||
});
|
||||
|
||||
self.nodes
|
||||
.iter()
|
||||
|
@ -187,7 +191,7 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
.flat_map(to_tensor)
|
||||
.for_each(|tensor| {
|
||||
self.scope
|
||||
.tensor_register_variable(tensor, node_position + 1)
|
||||
.tensor_register_variable(&tensor, node_position + 1)
|
||||
})
|
||||
});
|
||||
|
||||
|
@ -198,7 +202,10 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
node.input_types()
|
||||
.into_iter()
|
||||
.flat_map(to_tensor)
|
||||
.for_each(|tensor| self.scope.tensor_register_future_use(tensor, node_position))
|
||||
.for_each(|tensor| {
|
||||
self.scope
|
||||
.tensor_register_future_use(&tensor, node_position)
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -240,16 +247,33 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
let name = field.name();
|
||||
let ty = field.ty();
|
||||
|
||||
quote! {
|
||||
#name: #ty,
|
||||
if matches!(&field, Type::Tensor(_)) {
|
||||
quote! {
|
||||
#name: burn::module::Param<#ty>,
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#name: #ty,
|
||||
}
|
||||
}
|
||||
})
|
||||
.for_each(|code| body.extend(code));
|
||||
|
||||
quote! {
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
#body
|
||||
// Add dummy field if no field is present to avoid empty struct
|
||||
// and make sure we can derive Module trait and use it in a model.
|
||||
if body.is_empty() {
|
||||
quote! {
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
_phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
#body
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -269,13 +293,24 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
.map(|field| field.name().clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
quote! {
|
||||
#[allow(dead_code)]
|
||||
pub fn new() -> Self {
|
||||
#body
|
||||
if fields.is_empty() {
|
||||
quote! {
|
||||
#[allow(dead_code)]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
_phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#[allow(dead_code)]
|
||||
pub fn new() -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#fields,)*
|
||||
Self {
|
||||
#(#fields,)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -295,12 +330,22 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
.map(|field| field.name().clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
quote! {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
#body
|
||||
if fields.is_empty() {
|
||||
quote! {
|
||||
pub fn new_with(_record: ModelRecord<B>) -> Self {
|
||||
Self {
|
||||
_phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
#body
|
||||
|
||||
Self {
|
||||
#(#fields,)*
|
||||
Self {
|
||||
#(#fields,)*
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -311,26 +356,19 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
let mut output_type_def = quote! {};
|
||||
let mut output_return_def = quote! {};
|
||||
|
||||
self.nodes
|
||||
.first()
|
||||
.unwrap()
|
||||
.input_types()
|
||||
.into_iter()
|
||||
.for_each(|input| {
|
||||
let name = input.name();
|
||||
let ty = input.ty();
|
||||
self.graph_input_types.iter().for_each(|input| {
|
||||
let name = input.name().clone();
|
||||
let ty = input.ty().clone();
|
||||
|
||||
input_def.extend(quote! {
|
||||
#name: #ty,
|
||||
input_def.extend(quote! {
|
||||
#name: #ty,
|
||||
|
||||
})
|
||||
});
|
||||
})
|
||||
});
|
||||
|
||||
let output_types = self.nodes.last().unwrap().output_types();
|
||||
let multiple_output = self.graph_output_types.len() > 1;
|
||||
|
||||
let multiple_output = output_types.len() > 1;
|
||||
|
||||
output_types.into_iter().for_each(|output| {
|
||||
self.graph_output_types.iter().for_each(|output| {
|
||||
let name = output.name();
|
||||
let ty = output.ty();
|
||||
|
||||
|
@ -379,6 +417,48 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the input and output types of the graph using the passed in names.
|
||||
/// The names must be unique and match the names of the inputs and outputs of the nodes.
|
||||
/// The order will be preserved.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input_names` - The names of the inputs of the graph.
|
||||
/// * `output_names` - The names of the outputs of the graph.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the graph is empty.
|
||||
pub fn register_input_output(&mut self, input_names: Vec<String>, output_names: Vec<String>) {
|
||||
assert!(
|
||||
!self.nodes.is_empty(),
|
||||
"Cannot register input and output types for an empty graph."
|
||||
);
|
||||
|
||||
// Get the unique names of each input of the nodes
|
||||
let mut inputs = HashMap::new();
|
||||
let mut outputs = HashMap::new();
|
||||
for node in self.nodes.iter() {
|
||||
for input in node.input_types() {
|
||||
inputs.insert(input.name().to_string(), input);
|
||||
}
|
||||
for output in node.output_types() {
|
||||
outputs.insert(output.name().to_string(), output);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the input and output types of the graph using passed in names
|
||||
input_names.iter().for_each(|input| {
|
||||
self.graph_input_types
|
||||
.push(inputs.get(input).unwrap().clone());
|
||||
});
|
||||
|
||||
output_names.iter().for_each(|output| {
|
||||
self.graph_output_types
|
||||
.push(outputs.get(output).unwrap().clone());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
|
|
|
@ -77,7 +77,7 @@ pub enum Node<PS: PrecisionSettings> {
|
|||
MaxPool2d(MaxPool2dNode),
|
||||
Linear(LinearNode<PS>),
|
||||
BatchNorm(BatchNormNode<PS>),
|
||||
Constant(ConstantNode),
|
||||
Constant(ConstantNode<PS>),
|
||||
Unary(UnaryNode),
|
||||
Reshape(ReshapeNode),
|
||||
Concat(ConcatNode),
|
||||
|
@ -174,7 +174,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for Node<PS> {
|
|||
#[cfg(test)]
|
||||
pub(crate) mod tests {
|
||||
use crate::burn::{
|
||||
codegen::ToTokens,
|
||||
graph::BurnGraph,
|
||||
node::{conv2d::Conv2dNode, matmul::MatmulNode, test::assert_tokens, NodeCodegen},
|
||||
TensorType,
|
||||
|
@ -185,14 +184,18 @@ pub(crate) mod tests {
|
|||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
||||
fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + 'static>(
|
||||
pub(crate) fn one_node_graph<T: NodeCodegen<FullPrecisionSettings> + 'static>(
|
||||
node_gen: T,
|
||||
forward: TokenStream,
|
||||
input_names: Vec<String>,
|
||||
output_names: Vec<String>,
|
||||
) {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(node_gen);
|
||||
|
||||
graph.register_input_output(input_names, output_names);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
@ -200,11 +203,15 @@ pub(crate) mod tests {
|
|||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend>{}
|
||||
pub struct Model<B: Backend> {
|
||||
_phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
Self { }
|
||||
pub fn new_with(_record: ModelRecord<B>) -> Self {
|
||||
Self {
|
||||
_phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return)]
|
||||
|
@ -215,42 +222,6 @@ pub(crate) mod tests {
|
|||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
|
||||
pub(crate) fn codegen_unary_operator<
|
||||
const N: usize,
|
||||
T: NodeCodegen<FullPrecisionSettings> + 'static,
|
||||
>(
|
||||
node_gen: T,
|
||||
function: TokenStream,
|
||||
) {
|
||||
let forward = |function, tensor_dim| {
|
||||
quote! {
|
||||
pub fn forward(&self, tensor1: Tensor<B, #tensor_dim>) -> Tensor<B, #tensor_dim> {
|
||||
#function
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
one_node_graph(node_gen, forward(function, N.to_tokens()));
|
||||
}
|
||||
|
||||
pub(crate) fn codegen_binary_operator<
|
||||
const N: usize,
|
||||
T: NodeCodegen<FullPrecisionSettings> + 'static,
|
||||
>(
|
||||
node_gen: T,
|
||||
function: TokenStream,
|
||||
) {
|
||||
let forward = |function, tensor_dim| {
|
||||
quote! {
|
||||
pub fn forward(&self, tensor1: Tensor<B, #tensor_dim>, tensor2: Tensor<B, #tensor_dim>) -> Tensor<B, #tensor_dim> {
|
||||
#function
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
one_node_graph(node_gen, forward(function, N.to_tokens()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_codegen_two_nodes() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
@ -269,6 +240,11 @@ pub(crate) mod tests {
|
|||
Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor4".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
@ -333,6 +309,11 @@ pub(crate) mod tests {
|
|||
TensorType::new_float("output", 4),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["output".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
|
|
@ -102,13 +102,13 @@ macro_rules! batch_norm_serialize {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for BatchNormNode<PS> {
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.input)]
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
Some(Type::Other(&self.field))
|
||||
Some(Type::Other(self.field.clone()))
|
||||
}
|
||||
|
||||
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
|
||||
|
@ -181,6 +181,8 @@ mod tests {
|
|||
BatchNormConfig::new(128),
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{Scope, TensorType, Type};
|
||||
use crate::burn::{Scope, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
@ -32,9 +32,9 @@ type FnPointer = Arc<dyn Fn(TokenStream, TokenStream) -> TokenStream>;
|
|||
/// Node for all binary operators.
|
||||
#[derive(Clone, new)]
|
||||
pub struct BinaryNode {
|
||||
pub lhs: TensorType,
|
||||
pub rhs: TensorType,
|
||||
pub output: TensorType,
|
||||
pub lhs: Type,
|
||||
pub rhs: Type,
|
||||
pub output: Type,
|
||||
pub binary_type: BinaryType,
|
||||
function: FnPointer,
|
||||
}
|
||||
|
@ -56,17 +56,35 @@ impl std::fmt::Debug for BinaryNode {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for BinaryNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![self.output.clone()]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.lhs), Type::Tensor(&self.rhs)]
|
||||
vec![self.lhs.clone(), self.rhs.clone()]
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let lhs = scope.tensor_use_owned(&self.lhs, node_position);
|
||||
let rhs = scope.tensor_use_owned(&self.rhs, node_position);
|
||||
let output = &self.output.name;
|
||||
// Get the lhs name in the form of token stream.
|
||||
let lhs = match &self.lhs {
|
||||
Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position),
|
||||
Type::Scalar(scalar) => {
|
||||
let name = scalar.name.clone();
|
||||
quote! { #name }
|
||||
}
|
||||
_ => panic!("lhs must be a tensor or scalar"),
|
||||
};
|
||||
|
||||
// Get the rhs name in the form of token stream
|
||||
let rhs = match &self.rhs {
|
||||
Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position),
|
||||
Type::Scalar(scalar) => {
|
||||
let name = scalar.name.clone();
|
||||
quote! { #name }
|
||||
}
|
||||
_ => panic!("rhs must be a tensor or scalar"),
|
||||
};
|
||||
|
||||
let output = &self.output.name();
|
||||
let function = (self.function)(lhs, rhs);
|
||||
|
||||
quote! {
|
||||
|
@ -80,27 +98,53 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for BinaryNode {
|
|||
}
|
||||
|
||||
impl BinaryNode {
|
||||
pub(crate) fn add(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
|
||||
let function = move |lhs, rhs| quote! { #lhs.add(#rhs) };
|
||||
pub(crate) fn add(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.add(#rhs) },
|
||||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.add_scalar(#rhs) },
|
||||
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.add_scalar(#lhs) },
|
||||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs + #rhs },
|
||||
_ => panic!("Addition is supported for tensor and scalar only"),
|
||||
};
|
||||
|
||||
Self::new(lhs, rhs, output, BinaryType::Add, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn sub(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
|
||||
let function = move |lhs, rhs| quote! { #lhs.sub(#rhs) };
|
||||
pub(crate) fn sub(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.sub(#rhs) },
|
||||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.sub_scalar(#rhs) },
|
||||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs - #rhs },
|
||||
_ => panic!("Subtraction is supported for tensor and scalar only"),
|
||||
};
|
||||
|
||||
Self::new(lhs, rhs, output, BinaryType::Sub, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn mul(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
|
||||
let function = move |lhs, rhs| quote! { #lhs.mul(#rhs) };
|
||||
pub(crate) fn mul(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.mul(#rhs) },
|
||||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.mul_scalar(#rhs) },
|
||||
(Type::Scalar(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #rhs.mul_scalar(#lhs) },
|
||||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs * #rhs },
|
||||
_ => panic!("Multiplication is supported for tensor and scalar only"),
|
||||
};
|
||||
|
||||
Self::new(lhs, rhs, output, BinaryType::Mul, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn div(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
|
||||
let function = move |lhs, rhs| quote! { #lhs.div(#rhs) };
|
||||
pub(crate) fn div(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = match (&lhs, &rhs) {
|
||||
(Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.div(#rhs) },
|
||||
(Type::Tensor(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs.div_scalar(#rhs) },
|
||||
(Type::Scalar(_), Type::Scalar(_)) => move |lhs, rhs| quote! { #lhs / #rhs },
|
||||
_ => panic!("Division is supported for tensor and scalar only"),
|
||||
};
|
||||
|
||||
Self::new(lhs, rhs, output, BinaryType::Div, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn equal(lhs: TensorType, rhs: TensorType, output: TensorType) -> Self {
|
||||
pub(crate) fn equal(lhs: Type, rhs: Type, output: Type) -> Self {
|
||||
let function = move |lhs, rhs| quote! { #lhs.equal(#rhs) };
|
||||
Self::new(lhs, rhs, output, BinaryType::Equal, Arc::new(function))
|
||||
}
|
||||
|
@ -110,48 +154,93 @@ impl BinaryNode {
|
|||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::burn::node::tests::codegen_binary_operator;
|
||||
use crate::burn::TensorType;
|
||||
use crate::burn::node::tests::one_node_graph;
|
||||
use crate::burn::{ScalarKind, ScalarType, TensorType};
|
||||
|
||||
macro_rules! test_binary_operator {
|
||||
macro_rules! test_binary_operator_on_tensors {
|
||||
($operator:ident) => {{
|
||||
codegen_binary_operator::<4, _>(
|
||||
one_node_graph(
|
||||
BinaryNode::$operator(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
TensorType::new_float("tensor3", 4),
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor3", 4)),
|
||||
),
|
||||
quote! {
|
||||
let tensor3 = tensor1.$operator(tensor2);
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor3 = tensor1.$operator(tensor2);
|
||||
|
||||
tensor3
|
||||
tensor3
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! test_binary_operator_on_tensor_and_scalar {
|
||||
($operator:ident, $burn_operator:ident) => {{
|
||||
one_node_graph(
|
||||
BinaryNode::$operator(
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)),
|
||||
Type::Tensor(TensorType::new_float("tensor3", 4)),
|
||||
),
|
||||
quote! {
|
||||
pub fn forward(&self, scalar1: f32, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor3 = tensor1.$burn_operator(scalar1);
|
||||
|
||||
tensor3
|
||||
}
|
||||
},
|
||||
vec!["scalar1".to_string(), "tensor1".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
}};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_add() {
|
||||
test_binary_operator!(add);
|
||||
test_binary_operator_on_tensors!(add);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_add_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(add, add_scalar);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_sub() {
|
||||
test_binary_operator!(sub);
|
||||
test_binary_operator_on_tensors!(sub);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_sub_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(sub, sub_scalar);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_mul() {
|
||||
test_binary_operator!(mul);
|
||||
test_binary_operator_on_tensors!(mul);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_mul_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(mul, mul_scalar);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_div() {
|
||||
test_binary_operator!(div);
|
||||
test_binary_operator_on_tensors!(div);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_div_scalar() {
|
||||
test_binary_operator_on_tensor_and_scalar!(div, div_scalar);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_codegen_equal() {
|
||||
test_binary_operator!(equal);
|
||||
test_binary_operator_on_tensors!(equal);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,11 +14,14 @@ pub struct ConcatNode {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for ConcatNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
self.inputs.iter().map(Type::Tensor).collect()
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|t| Type::Tensor(t.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
|
@ -65,6 +68,11 @@ mod tests {
|
|||
1,
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
@ -72,12 +80,17 @@ mod tests {
|
|||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend>{}
|
||||
pub struct Model<B: Backend> {
|
||||
_phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
Self { }
|
||||
pub fn new_with(_record: ModelRecord<B>) -> Self {
|
||||
Self {
|
||||
_phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return)]
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor3 = burn::tensor::Tensor::cat(vec![tensor1, tensor2], 1);
|
||||
|
|
|
@ -1,72 +1,177 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{OtherType, Scope, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use crate::burn::{ScalarKind, ScalarType, Scope, TensorType, ToTokens, Type};
|
||||
use burn::{
|
||||
module::ParamId,
|
||||
record::{ParamSerde, PrecisionSettings},
|
||||
tensor::DataSerialize,
|
||||
};
|
||||
use proc_macro2::{Ident, Span, TokenStream};
|
||||
use quote::quote;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConstantNode {
|
||||
pub struct ConstantNode<PS: PrecisionSettings> {
|
||||
pub name: String,
|
||||
pub value: ConstantValue,
|
||||
output_ty: OtherType,
|
||||
pub value: ConstantValue<PS>,
|
||||
pub output: Type,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TensorValue<PS: PrecisionSettings> {
|
||||
Float(DataSerialize<PS::FloatElem>),
|
||||
Int(DataSerialize<PS::IntElem>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, new)]
|
||||
pub enum ConstantValue {
|
||||
Int(i32),
|
||||
Float(f32),
|
||||
Bool(bool),
|
||||
pub enum ConstantValue<PS: PrecisionSettings> {
|
||||
/// Float constant.
|
||||
Float32(f32),
|
||||
Float64(f64),
|
||||
|
||||
/// Integer constant.
|
||||
Int32(i32),
|
||||
Int64(i64),
|
||||
|
||||
/// Tensor constant.
|
||||
Tensor(TensorType, TensorValue<PS>),
|
||||
}
|
||||
|
||||
impl ConstantValue {
|
||||
impl<PS: PrecisionSettings> ConstantValue<PS> {
|
||||
pub fn ty_tokens(&self) -> TokenStream {
|
||||
match self {
|
||||
ConstantValue::Int(_) => quote! { i32 },
|
||||
ConstantValue::Float(_) => quote! { f32 },
|
||||
ConstantValue::Bool(_) => quote! { bool },
|
||||
ConstantValue::Float32(_) => quote! { f32 },
|
||||
ConstantValue::Float64(_) => quote! { f64 },
|
||||
ConstantValue::Int32(_) => quote! { i32 },
|
||||
ConstantValue::Int64(_) => quote! { i64 },
|
||||
ConstantValue::Tensor(tensor_type, _) => {
|
||||
let ty = tensor_type.ty();
|
||||
quote! { burn::module::Param<#ty>}
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn val_tokens(&self) -> TokenStream {
|
||||
match self {
|
||||
ConstantValue::Int(val) => quote! { #val },
|
||||
ConstantValue::Float(val) => quote! { #val },
|
||||
ConstantValue::Bool(val) => quote! { #val },
|
||||
ConstantValue::Float32(val) => quote! { #val },
|
||||
ConstantValue::Float64(val) => quote! { #val },
|
||||
ConstantValue::Int32(val) => quote! { #val },
|
||||
ConstantValue::Int64(val) => quote! { #val },
|
||||
ConstantValue::Tensor(_, _) => {
|
||||
panic!("Tensor constant is not assignable.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConstantNode {
|
||||
pub fn new(name: String, value: ConstantValue) -> Self {
|
||||
let output_ty = OtherType::new(name.clone(), value.ty_tokens());
|
||||
|
||||
impl<PS: PrecisionSettings> ConstantNode<PS> {
|
||||
pub fn new(name: String, value: ConstantValue<PS>, output: Type) -> Self {
|
||||
Self {
|
||||
name,
|
||||
value,
|
||||
output_ty,
|
||||
output,
|
||||
}
|
||||
}
|
||||
pub fn constant_value_into_type(&self) -> Type {
|
||||
let name = Ident::new(self.name.as_str(), Span::call_site());
|
||||
match &self.value {
|
||||
ConstantValue::Float32(_) => Type::Scalar(ScalarType {
|
||||
name,
|
||||
kind: ScalarKind::Float32,
|
||||
}),
|
||||
ConstantValue::Float64(_) => Type::Scalar(ScalarType {
|
||||
name,
|
||||
kind: ScalarKind::Float64,
|
||||
}),
|
||||
ConstantValue::Int32(_) => Type::Scalar(ScalarType {
|
||||
name,
|
||||
kind: ScalarKind::Int32,
|
||||
}),
|
||||
ConstantValue::Int64(_) => Type::Scalar(ScalarType {
|
||||
name,
|
||||
kind: ScalarKind::Int64,
|
||||
}),
|
||||
ConstantValue::Tensor(tensor_type, _) => Type::Tensor(tensor_type.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode<PS> {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Other(&self.output_ty)]
|
||||
vec![self.output.clone()]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
match &self.value {
|
||||
ConstantValue::Tensor(tensor_type, _) => Some(Type::Tensor(tensor_type.clone())),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
|
||||
match &self.value {
|
||||
ConstantValue::Tensor(tensor_type, _) => {
|
||||
let ty = tensor_type.ty();
|
||||
let name = Ident::new(self.name.as_ref(), Span::call_site());
|
||||
let shape = tensor_type.clone().shape.unwrap().to_tokens();
|
||||
let dim = tensor_type.clone().dim.to_tokens();
|
||||
|
||||
if with_record {
|
||||
Some(quote! {
|
||||
let #name = record.#name.map(|tensor| tensor.set_require_grad(false));
|
||||
})
|
||||
} else {
|
||||
Some(quote! {
|
||||
let #name: burn::module::Param<#ty> = burn::module::Param::new(
|
||||
burn::module::ParamId::new(),
|
||||
Tensor::<B, #dim>::zeros(#shape).set_require_grad(false),
|
||||
);
|
||||
})
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
|
||||
let name = Ident::new(self.name.as_ref(), Span::call_site());
|
||||
let val = self.value.val_tokens();
|
||||
let ty = self.value.ty_tokens();
|
||||
let output = self.output.name();
|
||||
|
||||
quote! {
|
||||
let #name: #ty = #val;
|
||||
match &self.value {
|
||||
ConstantValue::Tensor(_, _) => {
|
||||
quote! {
|
||||
let #output = self.#name.val();
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let val = self.value.val_tokens();
|
||||
let ty = self.value.ty_tokens();
|
||||
|
||||
quote! {
|
||||
let #output: #ty = #val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn into_node(self) -> Node<PS> {
|
||||
Node::Constant(self)
|
||||
}
|
||||
|
||||
fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
if let ConstantValue::Tensor(_, ds) = &self.value {
|
||||
let data: DataSerialize<PS::FloatElem> = match ds {
|
||||
TensorValue::Float(data) => data.clone().convert(),
|
||||
TensorValue::Int(data) => data.clone().convert(),
|
||||
};
|
||||
let data = ParamSerde::new(ParamId::new().into_string(), data);
|
||||
return data.serialize(serializer);
|
||||
}
|
||||
|
||||
S::serialize_none(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO add test missing for constant node (@antimora 8/2/2023)
|
||||
|
|
|
@ -47,13 +47,13 @@ impl<PS: PrecisionSettings> Conv2dNode<PS> {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for Conv2dNode<PS> {
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.input)]
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
Some(Type::Other(&self.field))
|
||||
Some(Type::Other(self.field.clone()))
|
||||
}
|
||||
|
||||
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
|
||||
|
@ -154,6 +154,8 @@ mod tests {
|
|||
Conv2dConfig::new([3, 3], [3, 3]).with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
|
|
@ -47,14 +47,14 @@ impl<PS: PrecisionSettings> LinearNode<PS> {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for LinearNode<PS> {
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.input)]
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
Some(Type::Other(&self.field))
|
||||
Some(Type::Other(self.field.clone()))
|
||||
}
|
||||
|
||||
fn field_init(&self, with_record: bool) -> Option<TokenStream> {
|
||||
|
@ -136,6 +136,8 @@ mod tests {
|
|||
LinearConfig::new(128, 128),
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
|
|
@ -13,11 +13,14 @@ pub struct MatmulNode {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for MatmulNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.lhs), Type::Tensor(&self.rhs)]
|
||||
vec![
|
||||
Type::Tensor(self.lhs.clone()),
|
||||
Type::Tensor(self.rhs.clone()),
|
||||
]
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
|
@ -57,6 +60,11 @@ mod tests {
|
|||
TensorType::new_float("tensor3", 4),
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
vec!["tensor3".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
@ -64,12 +72,17 @@ mod tests {
|
|||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend>{}
|
||||
pub struct Model<B: Backend> {
|
||||
_phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
Self { }
|
||||
pub fn new_with(_record: ModelRecord<B>) -> Self {
|
||||
Self {
|
||||
_phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return)]
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>, tensor2: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor3 = tensor1.matmul(tensor2);
|
||||
|
|
|
@ -37,13 +37,13 @@ impl MaxPool2dNode {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for MaxPool2dNode {
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.input)]
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
fn field_type(&self) -> Option<Type> {
|
||||
Some(Type::Other(&self.field))
|
||||
Some(Type::Other(self.field.clone()))
|
||||
}
|
||||
|
||||
fn field_init(&self, _with_record: bool) -> Option<TokenStream> {
|
||||
|
@ -109,6 +109,8 @@ mod tests {
|
|||
.with_padding(PaddingConfig2d::Valid),
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
|
|
@ -13,11 +13,11 @@ pub struct ReshapeNode {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for ReshapeNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![Type::Tensor(self.output.clone())]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.input)]
|
||||
vec![Type::Tensor(self.input.clone())]
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
|
@ -56,6 +56,8 @@ mod tests {
|
|||
[4, 4, 4, 4].into(),
|
||||
));
|
||||
|
||||
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::{
|
||||
module::Module,
|
||||
|
@ -63,11 +65,15 @@ mod tests {
|
|||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model <B: Backend>{}
|
||||
pub struct Model<B: Backend> {
|
||||
_phantom: core::marker::PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
Self { }
|
||||
pub fn new_with(_record: ModelRecord<B>) -> Self {
|
||||
Self {
|
||||
_phantom: core::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
#[allow(clippy::let_and_return)]
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::{Node, NodeCodegen};
|
||||
use crate::burn::{Scope, TensorType, ToTokens, Type};
|
||||
use crate::burn::{Scope, ToTokens, Type};
|
||||
use burn::record::PrecisionSettings;
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::quote;
|
||||
|
@ -11,8 +11,8 @@ type FnPointer = Arc<dyn Fn(TokenStream) -> TokenStream>;
|
|||
/// Node for all unary operators.
|
||||
#[derive(Clone, new)]
|
||||
pub struct UnaryNode {
|
||||
pub input: TensorType,
|
||||
pub output: TensorType,
|
||||
pub input: Type,
|
||||
pub output: Type,
|
||||
pub kind: UnaryNodeKind,
|
||||
function: FnPointer,
|
||||
}
|
||||
|
@ -20,20 +20,22 @@ pub struct UnaryNode {
|
|||
/// Type of unary node.
|
||||
#[derive(Clone)]
|
||||
pub enum UnaryNodeKind {
|
||||
Cast,
|
||||
Flatten,
|
||||
LogSoftmax,
|
||||
Relu,
|
||||
Sigmoid,
|
||||
LogSoftmax,
|
||||
Transpose,
|
||||
}
|
||||
|
||||
impl UnaryNodeKind {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
Self::Cast => "cast",
|
||||
Self::Flatten => "flatten",
|
||||
Self::LogSoftmax => "log_softmax",
|
||||
Self::Relu => "relu",
|
||||
Self::Sigmoid => "sigmoid",
|
||||
Self::LogSoftmax => "log_softmax",
|
||||
Self::Transpose => "transpose",
|
||||
}
|
||||
}
|
||||
|
@ -55,16 +57,26 @@ impl std::fmt::Debug for UnaryNode {
|
|||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
|
||||
fn output_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.output)]
|
||||
vec![self.output.clone()]
|
||||
}
|
||||
|
||||
fn input_types(&self) -> Vec<Type> {
|
||||
vec![Type::Tensor(&self.input)]
|
||||
vec![self.input.clone()]
|
||||
}
|
||||
|
||||
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
|
||||
let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output = &self.output.name;
|
||||
// Get the lhs name in the form of token stream.
|
||||
let input = match &self.input {
|
||||
Type::Tensor(tensor) => scope.tensor_use_owned(tensor, node_position),
|
||||
Type::Scalar(scalar) => {
|
||||
let name = scalar.name.clone();
|
||||
quote! { #name }
|
||||
}
|
||||
_ => panic!("lhs must be a tensor or scalar"),
|
||||
};
|
||||
|
||||
// let input = scope.tensor_use_owned(&self.input, node_position);
|
||||
let output = &self.output.name();
|
||||
let function = (self.function)(input);
|
||||
|
||||
quote! {
|
||||
|
@ -78,12 +90,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for UnaryNode {
|
|||
}
|
||||
|
||||
impl UnaryNode {
|
||||
pub(crate) fn flatten(
|
||||
input: TensorType,
|
||||
output: TensorType,
|
||||
start_dim: usize,
|
||||
end_dim: usize,
|
||||
) -> Self {
|
||||
pub(crate) fn flatten(input: Type, output: Type, start_dim: usize, end_dim: usize) -> Self {
|
||||
let start_dim = start_dim.to_tokens();
|
||||
let end_dim = end_dim.to_tokens();
|
||||
let function = move |input| quote! { #input.flatten(#start_dim, #end_dim) };
|
||||
|
@ -91,109 +98,193 @@ impl UnaryNode {
|
|||
Self::new(input, output, UnaryNodeKind::Flatten, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn relu(input: TensorType, output: TensorType) -> Self {
|
||||
pub(crate) fn relu(input: Type, output: Type) -> Self {
|
||||
let function = move |input| quote! { burn::tensor::activation::relu(#input) };
|
||||
Self::new(input, output, UnaryNodeKind::Relu, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn sigmoid(input: TensorType, output: TensorType) -> Self {
|
||||
pub(crate) fn sigmoid(input: Type, output: Type) -> Self {
|
||||
let function = move |input| quote! { burn::tensor::activation::sigmoid(#input) };
|
||||
Self::new(input, output, UnaryNodeKind::Sigmoid, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn log_softmax(input: TensorType, output: TensorType, dim: usize) -> Self {
|
||||
pub(crate) fn log_softmax(input: Type, output: Type, dim: usize) -> Self {
|
||||
let dim = dim.to_tokens();
|
||||
let function = move |input| quote! { burn::tensor::activation::log_softmax(#input, #dim) };
|
||||
Self::new(input, output, UnaryNodeKind::LogSoftmax, Arc::new(function))
|
||||
}
|
||||
|
||||
pub(crate) fn transpose(input: TensorType, output: TensorType) -> Self {
|
||||
pub(crate) fn transpose(input: Type, output: Type) -> Self {
|
||||
let function = move |input| quote! { #input.transpose() };
|
||||
Self::new(input, output, UnaryNodeKind::Transpose, Arc::new(function))
|
||||
}
|
||||
|
||||
/// Casts the input to the output type.
|
||||
///
|
||||
/// Currently this function only supports the following conversions:
|
||||
/// 1) scalar -> scalar
|
||||
///
|
||||
/// TODO: Implement the following conversions:
|
||||
/// 2) tensor int -> tensor float
|
||||
/// 3) tensor float -> tensor int
|
||||
/// 4) tensor -> scalar
|
||||
/// 5) scalar -> tensor
|
||||
pub(crate) fn cast(input: Type, output: Type) -> Self {
|
||||
let function = match output.clone() {
|
||||
Type::Scalar(scalar) => {
|
||||
let ty = scalar.ty();
|
||||
move |input| quote! { #input as #ty }
|
||||
}
|
||||
Type::Tensor(_tensor) => {
|
||||
// TODO: Implement this after tensor Int is implemented (@antimora 8/2/2023)
|
||||
// TODO: If the input is scalar and the output type is a tensor,
|
||||
// we should generate another code block. (@antimora 8/4/2023)
|
||||
// Tensor::from_data(Data::from([#input]).convert()).unsqueeze();
|
||||
todo!()
|
||||
}
|
||||
|
||||
_ => panic!("output must be a tensor"),
|
||||
};
|
||||
|
||||
Self::new(input, output, UnaryNodeKind::Cast, Arc::new(function))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::burn::node::tests::codegen_unary_operator;
|
||||
use crate::burn::TensorType;
|
||||
use crate::burn::node::tests::one_node_graph;
|
||||
use crate::burn::{ScalarKind, ScalarType, TensorType};
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_flatten() {
|
||||
codegen_unary_operator::<4, _>(
|
||||
one_node_graph(
|
||||
UnaryNode::flatten(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
1,
|
||||
2,
|
||||
),
|
||||
quote! {
|
||||
let tensor2 = tensor1.flatten(1, 2);
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = tensor1.flatten(1, 2);
|
||||
|
||||
tensor2
|
||||
tensor2
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_relu() {
|
||||
codegen_unary_operator::<4, _>(
|
||||
one_node_graph(
|
||||
UnaryNode::relu(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
),
|
||||
quote! {
|
||||
let tensor2 = burn::tensor::activation::relu(tensor1);
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = burn::tensor::activation::relu(tensor1);
|
||||
|
||||
tensor2
|
||||
tensor2
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_sigmoid() {
|
||||
codegen_unary_operator::<4, _>(
|
||||
one_node_graph(
|
||||
UnaryNode::sigmoid(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
),
|
||||
quote! {
|
||||
let tensor2 = burn::tensor::activation::sigmoid(tensor1);
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = burn::tensor::activation::sigmoid(tensor1);
|
||||
|
||||
tensor2
|
||||
tensor2
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_log_softmax() {
|
||||
codegen_unary_operator::<4, _>(
|
||||
one_node_graph(
|
||||
UnaryNode::log_softmax(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
1,
|
||||
),
|
||||
quote! {
|
||||
let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1);
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = burn::tensor::activation::log_softmax(tensor1, 1);
|
||||
|
||||
tensor2
|
||||
tensor2
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_transpose() {
|
||||
codegen_unary_operator::<4, _>(
|
||||
one_node_graph(
|
||||
UnaryNode::transpose(
|
||||
TensorType::new_float("tensor1", 4),
|
||||
TensorType::new_float("tensor2", 4),
|
||||
Type::Tensor(TensorType::new_float("tensor1", 4)),
|
||||
Type::Tensor(TensorType::new_float("tensor2", 4)),
|
||||
),
|
||||
quote! {
|
||||
let tensor2 = tensor1.transpose();
|
||||
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let tensor2 = tensor1.transpose();
|
||||
|
||||
tensor2
|
||||
tensor2
|
||||
}
|
||||
},
|
||||
vec!["tensor1".to_string()],
|
||||
vec!["tensor2".to_string()],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unary_codegen_cast() {
|
||||
one_node_graph(
|
||||
UnaryNode::cast(
|
||||
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float64)),
|
||||
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float32)),
|
||||
),
|
||||
quote! {
|
||||
pub fn forward(&self, scalar1: f64) -> f32 {
|
||||
let scalar2 = scalar1 as f32;
|
||||
|
||||
scalar2
|
||||
}
|
||||
},
|
||||
vec!["scalar1".to_string()],
|
||||
vec!["scalar2".to_string()],
|
||||
);
|
||||
one_node_graph(
|
||||
UnaryNode::cast(
|
||||
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Float32)),
|
||||
Type::Scalar(ScalarType::new("scalar2", ScalarKind::Float64)),
|
||||
),
|
||||
quote! {
|
||||
pub fn forward(&self, scalar1: f32) -> f64 {
|
||||
let scalar2 = scalar1 as f64;
|
||||
|
||||
scalar2
|
||||
}
|
||||
},
|
||||
vec!["scalar1".to_string()],
|
||||
vec!["scalar2".to_string()],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,64 +10,114 @@ pub struct TensorType {
|
|||
pub name: Ident,
|
||||
pub dim: usize,
|
||||
pub kind: TensorKind,
|
||||
pub shape: Option<Vec<usize>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TensorKind {
|
||||
Int,
|
||||
Float,
|
||||
Bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ScalarKind {
|
||||
Int32,
|
||||
Int64,
|
||||
Float32,
|
||||
Float64,
|
||||
Bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScalarType {
|
||||
pub name: Ident,
|
||||
pub kind: ScalarKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OtherType {
|
||||
pub name: Ident,
|
||||
pub ty: TokenStream,
|
||||
}
|
||||
|
||||
pub enum Type<'a> {
|
||||
Tensor(&'a TensorType),
|
||||
Other(&'a OtherType),
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Type {
|
||||
/// Tensor type.
|
||||
Tensor(TensorType),
|
||||
|
||||
/// Scalar type.
|
||||
Scalar(ScalarType),
|
||||
|
||||
// Other type (more flexible type).
|
||||
Other(OtherType),
|
||||
}
|
||||
|
||||
impl<'a> Type<'a> {
|
||||
impl Type {
|
||||
pub fn name(&self) -> &Ident {
|
||||
match self {
|
||||
Type::Tensor(tensor) => &tensor.name,
|
||||
Type::Scalar(scalar) => &scalar.name,
|
||||
Type::Other(other) => &other.name,
|
||||
}
|
||||
}
|
||||
pub fn ty(&self) -> TokenStream {
|
||||
match self {
|
||||
Type::Tensor(tensor) => tensor.ty(),
|
||||
Type::Scalar(scalar) => scalar.ty(),
|
||||
Type::Other(other) => other.ty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarType {
|
||||
pub fn new<S: AsRef<str>>(name: S, kind: ScalarKind) -> Self {
|
||||
Self {
|
||||
name: Ident::new(name.as_ref(), Span::call_site()),
|
||||
kind,
|
||||
}
|
||||
}
|
||||
pub fn ty(&self) -> TokenStream {
|
||||
match self.kind {
|
||||
ScalarKind::Int32 => quote! { i32 },
|
||||
ScalarKind::Int64 => quote! { i64 },
|
||||
ScalarKind::Float32 => quote! { f32 },
|
||||
ScalarKind::Float64 => quote! { f64 },
|
||||
ScalarKind::Bool => quote! { bool },
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorType {
|
||||
pub fn new<S: AsRef<str>>(name: S, dim: usize, kind: TensorKind) -> Self {
|
||||
pub fn new<S: AsRef<str>>(
|
||||
name: S,
|
||||
dim: usize,
|
||||
kind: TensorKind,
|
||||
shape: Option<Vec<usize>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: Ident::new(name.as_ref(), Span::call_site()),
|
||||
dim,
|
||||
kind,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
pub fn new_float<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
||||
Self::new(name, dim, TensorKind::Float)
|
||||
Self::new(name, dim, TensorKind::Float, None)
|
||||
}
|
||||
|
||||
pub fn new_int<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
||||
Self::new(name, dim, TensorKind::Int)
|
||||
Self::new(name, dim, TensorKind::Int, None)
|
||||
}
|
||||
|
||||
pub fn new_bool<S: AsRef<str>>(name: S, dim: usize) -> Self {
|
||||
Self::new(name, dim, TensorKind::Bool)
|
||||
Self::new(name, dim, TensorKind::Bool, None)
|
||||
}
|
||||
|
||||
pub fn ty(&self) -> TokenStream {
|
||||
let dim = self.dim.to_tokens();
|
||||
|
||||
// TODO use passed elem kind and do not assume float (@antimora 8/1/2023)
|
||||
quote! {
|
||||
Tensor<B, #dim>
|
||||
}
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use protobuf::Enum;
|
||||
|
||||
use super::{
|
||||
ir::{ArgType, Argument, AttributeValue, Node, NodeType, TensorArg},
|
||||
ir::{ArgType, Argument, AttributeValue, ElementType, Node, NodeType, TensorArg},
|
||||
op_configuration::flatten_config,
|
||||
protos::tensor_proto::DataType,
|
||||
};
|
||||
|
||||
struct TensorDimUpdater {
|
||||
|
@ -70,15 +73,13 @@ pub fn dim_inference(
|
|||
NodeType::Sub => same_as_input(node),
|
||||
NodeType::Pow => same_as_input(node),
|
||||
NodeType::Mul => same_as_input(node),
|
||||
NodeType::Cast => same_as_input(node),
|
||||
NodeType::Cast => cast_update_outputs(node),
|
||||
NodeType::Div => same_as_input(node),
|
||||
NodeType::Sqrt => same_as_input(node),
|
||||
NodeType::Softmax => same_as_input(node),
|
||||
NodeType::Erf => same_as_input(node),
|
||||
NodeType::ReduceMean => mean_update_outputs(node),
|
||||
NodeType::Constant => {
|
||||
node.outputs[0].ty = ArgType::Constant;
|
||||
}
|
||||
NodeType::Constant => constant_update_outputs(node),
|
||||
NodeType::Equal => same_as_input(node),
|
||||
NodeType::Shape => shape_update_outputs(node),
|
||||
NodeType::Unsqueeze => unsqueeze_update_outputs(node),
|
||||
|
@ -89,7 +90,9 @@ pub fn dim_inference(
|
|||
NodeType::Concat => concat_update_outputs(node),
|
||||
NodeType::Reshape => reshape_update_outputs(node),
|
||||
NodeType::Dropout => same_as_input(node),
|
||||
NodeType::GlobalAveragePool => same_as_input(node), //FIXME use correct output
|
||||
|
||||
//FIXME use correct output for GAP (@antimora 8/1/2023)
|
||||
NodeType::GlobalAveragePool => same_as_input(node),
|
||||
_ => todo!(
|
||||
"shape inference for {:?} is not implemented",
|
||||
node.node_type
|
||||
|
@ -102,6 +105,20 @@ pub fn dim_inference(
|
|||
updater.update_arguments(graph_outputs);
|
||||
}
|
||||
|
||||
fn constant_update_outputs(node: &mut Node) {
|
||||
// Fix the tensor dimension of the output when the value is tensor
|
||||
let output = &mut node.outputs[0];
|
||||
match node.attrs.get("value") {
|
||||
Some(value) => match &value {
|
||||
AttributeValue::Tensor(tensor) => {
|
||||
output.ty = ArgType::Tensor(TensorArg { dim: tensor.dim });
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
None => panic!("Constant node must have a value attribute"),
|
||||
};
|
||||
}
|
||||
|
||||
/// Infer the shape of the output tensor of a Conv2d node
|
||||
fn linear_update_outputs(node: &mut Node) {
|
||||
if node.inputs.len() != 1 {
|
||||
|
@ -119,6 +136,49 @@ fn linear_update_outputs(node: &mut Node) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Update the output type using "to" attribute
|
||||
fn cast_update_outputs(node: &mut Node) {
|
||||
if node.inputs.len() != 1 {
|
||||
panic!("Cast: multiple inputs are not supported");
|
||||
}
|
||||
let output = &mut node.outputs[0];
|
||||
|
||||
// Extract cast type and update the output tensor
|
||||
let elem_type = match node.attrs.get("to") {
|
||||
Some(value) => match &value {
|
||||
AttributeValue::Int64(type_id) => match DataType::from_i32(*type_id as i32).unwrap() {
|
||||
DataType::FLOAT => ElementType::Float32,
|
||||
DataType::INT32 => ElementType::Int32,
|
||||
DataType::INT64 => ElementType::Int64,
|
||||
DataType::DOUBLE => ElementType::Float64,
|
||||
_ => panic!("Cast: unsupported type"),
|
||||
},
|
||||
_ => panic!("'to' attribute must be an Int64"),
|
||||
},
|
||||
None => panic!("Constant node must have a value attribute"),
|
||||
};
|
||||
|
||||
match output.ty.clone() {
|
||||
ArgType::Tensor(tensor) => {
|
||||
if tensor.dim == 0 {
|
||||
// treat 0-dim tensor as scalar
|
||||
output.ty = ArgType::Scalar(elem_type);
|
||||
} else {
|
||||
todo!("Cast: update tensor type");
|
||||
// TODO track the type of the tensor elements (@antimora 8/1/2023)
|
||||
// output.ty = ArgType::Tensor(TensorArg {
|
||||
// dim: tensor.dim,
|
||||
// elem_type,
|
||||
// });
|
||||
}
|
||||
}
|
||||
ArgType::Scalar(_scalar) => {
|
||||
output.ty = ArgType::Scalar(elem_type);
|
||||
}
|
||||
_ => panic!("Only tensor input is valid"),
|
||||
}
|
||||
}
|
||||
|
||||
fn concat_update_outputs(node: &mut Node) {
|
||||
let tensor = node
|
||||
.inputs
|
||||
|
@ -184,7 +244,7 @@ fn unsqueeze_update_outputs(node: &mut Node) {
|
|||
let dim = match node_input.clone().ty {
|
||||
ArgType::Tensor(tensor) => tensor.dim,
|
||||
ArgType::Shape(dim) => dim,
|
||||
ArgType::Constant => panic!("Needs shape or tensor"),
|
||||
ArgType::Scalar(_) => panic!("Needs shape or tensor"),
|
||||
};
|
||||
|
||||
node.outputs[0].ty = ArgType::Tensor(TensorArg { dim: dim + 1 });
|
||||
|
|
|
@ -36,6 +36,16 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
|||
let onnx_model: ModelProto =
|
||||
Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
|
||||
|
||||
log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len());
|
||||
log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len());
|
||||
|
||||
log::debug!(
|
||||
"Number of initializers: {:?}",
|
||||
onnx_model.graph.initializer.len()
|
||||
);
|
||||
|
||||
log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
|
||||
|
||||
// Convert the nodes
|
||||
let mut nodes: Vec<Node> = vec![];
|
||||
for onnx_node in onnx_model.graph.node.iter() {
|
||||
|
@ -46,10 +56,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
|||
move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
|
||||
|
||||
// Get the topological sort of the nodes and the top nodes
|
||||
let (ts, top_nodes) = get_top_nodes(&nodes);
|
||||
|
||||
// Sort the nodes
|
||||
top_sort_nodes(&mut nodes, ts);
|
||||
top_sort_nodes(&mut nodes);
|
||||
|
||||
// Collect inputs, outputs and initializers
|
||||
let check_if_initializer: HashSet<String> = onnx_model
|
||||
|
@ -58,7 +65,8 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
|
|||
.iter()
|
||||
.map(|x| x.name.clone())
|
||||
.collect();
|
||||
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer, top_nodes);
|
||||
let mut inputs = collect_inputs(&onnx_model, &check_if_initializer);
|
||||
|
||||
let mut outputs = collect_outputs(&onnx_model, check_if_initializer);
|
||||
let states = collect_states(onnx_model);
|
||||
|
||||
|
@ -90,10 +98,7 @@ fn collect_states(onnx_model: ModelProto) -> Vec<State> {
|
|||
|
||||
for initializer in onnx_model.graph.initializer.iter() {
|
||||
let tensor_proto = initializer.clone();
|
||||
|
||||
let name = tensor_proto.name.clone();
|
||||
|
||||
// FIXME data conversion for the tensor is incorrect
|
||||
let tensor: Tensor = tensor_proto.try_into().unwrap();
|
||||
let ty = StateType::Tensor(tensor);
|
||||
let arg = State { name, ty };
|
||||
|
@ -108,7 +113,6 @@ fn collect_outputs(
|
|||
onnx_model: &ModelProto,
|
||||
check_if_initializer: HashSet<String>,
|
||||
) -> Vec<Argument> {
|
||||
// TODO: filter out the outputs that are not used in the graph
|
||||
let outputs: Vec<Argument> = onnx_model
|
||||
.graph
|
||||
.output
|
||||
|
@ -123,42 +127,30 @@ fn collect_outputs(
|
|||
fn collect_inputs(
|
||||
onnx_model: &ModelProto,
|
||||
check_if_initializer: &HashSet<String>,
|
||||
top_nodes: HashSet<String>,
|
||||
) -> Vec<Argument> {
|
||||
// Get the unique inputs
|
||||
let inputs: Vec<Argument> = onnx_model
|
||||
.graph
|
||||
.input
|
||||
.iter()
|
||||
.filter(|x| !check_if_initializer.contains(x.name.as_str()))
|
||||
.filter(|x| top_nodes.contains(&x.name))
|
||||
// .filter(|x| top_nodes.contains(&x.name))
|
||||
.map(|x| Argument::try_from(x.clone()).unwrap())
|
||||
.collect();
|
||||
inputs
|
||||
|
||||
// Convert to a vector and return
|
||||
inputs.into_iter().collect()
|
||||
}
|
||||
|
||||
/// Sort the nodes in topological order
|
||||
fn top_sort_nodes(nodes: &mut Vec<Node>, mut ts: TopologicalSort<Node>) {
|
||||
fn top_sort_nodes(nodes: &mut Vec<Node>) {
|
||||
let mut ts = topsort(nodes);
|
||||
*nodes = vec![];
|
||||
while let Some(node) = ts.pop() {
|
||||
nodes.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the top nodes in the graph
|
||||
fn get_top_nodes(nodes: &Vec<Node>) -> (TopologicalSort<Node>, HashSet<String>) {
|
||||
// Get the names of the top nodes (first nodes in the graph to receive the input)
|
||||
// Sometimes onnx will pass inputs to be used as weights and biases but they are not truly inputs
|
||||
let ts = topsort(nodes);
|
||||
let mut top_nodes: HashSet<String> = HashSet::new();
|
||||
|
||||
for node in ts.peek_all() {
|
||||
for input in node.inputs.iter() {
|
||||
top_nodes.insert(input.name.clone());
|
||||
}
|
||||
}
|
||||
(ts, top_nodes)
|
||||
}
|
||||
|
||||
fn to_string(bytes: Vec<u8>) -> String {
|
||||
from_utf8(bytes.as_slice()).unwrap().to_string()
|
||||
}
|
||||
|
|
|
@ -14,9 +14,9 @@ pub struct Argument {
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ArgType {
|
||||
Tensor(TensorArg),
|
||||
Scalar(ElementType),
|
||||
Shape(usize),
|
||||
Constant,
|
||||
Tensor(TensorArg),
|
||||
}
|
||||
|
||||
#[derive(new, Default, Debug, Clone)]
|
||||
|
@ -142,6 +142,15 @@ impl core::hash::Hash for Argument {
|
|||
}
|
||||
}
|
||||
|
||||
impl Eq for Argument {}
|
||||
|
||||
// Required by HashSet
|
||||
impl PartialEq for Argument {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.name == other.name
|
||||
}
|
||||
}
|
||||
|
||||
/// The list of supported node types (ONNX operators and some extra ones to map easily to Burn's ops)
|
||||
#[derive(Debug, Hash, Eq, PartialEq, EnumString, Clone, Display)]
|
||||
pub enum NodeType {
|
||||
|
|
|
@ -16,7 +16,7 @@ use crate::{
|
|||
batch_norm::BatchNormNode,
|
||||
binary::BinaryNode,
|
||||
concat::ConcatNode,
|
||||
constant::{ConstantNode, ConstantValue},
|
||||
constant::{ConstantNode, ConstantValue, TensorValue},
|
||||
conv2d::Conv2dNode,
|
||||
linear::LinearNode,
|
||||
matmul::MatmulNode,
|
||||
|
@ -24,7 +24,7 @@ use crate::{
|
|||
reshape::ReshapeNode,
|
||||
unary::UnaryNode,
|
||||
},
|
||||
TensorType,
|
||||
ScalarKind, ScalarType, TensorKind, TensorType, Type,
|
||||
},
|
||||
format_tokens,
|
||||
logger::init_log,
|
||||
|
@ -39,7 +39,7 @@ use crate::{
|
|||
|
||||
use super::{
|
||||
from_onnx::parse_onnx,
|
||||
ir::{ArgType, Argument, ONNXGraph, State, StateType, Tensor, TensorData},
|
||||
ir::{ArgType, Argument, ElementType, ONNXGraph, State, StateType, Tensor, TensorData},
|
||||
op_configuration::concat_config,
|
||||
};
|
||||
|
||||
|
@ -98,6 +98,8 @@ impl ModelGen {
|
|||
fn run(&self, is_build_script: bool) {
|
||||
log::info!("Starting to convert ONNX to Burn");
|
||||
|
||||
log::info!("Starting to convert ONNX to Burn");
|
||||
|
||||
// prepend the out_dir to the cargo_out_dir if this is a build script
|
||||
let out_dir = if is_build_script {
|
||||
let cargo_out_dir = env::var("OUT_DIR").expect("OUT_DIR env is not set");
|
||||
|
@ -112,6 +114,8 @@ impl ModelGen {
|
|||
|
||||
log::debug!("Output directory: {:?}", out_dir);
|
||||
|
||||
log::debug!("Output directory: {:?}", out_dir);
|
||||
|
||||
create_dir_all(&out_dir).unwrap();
|
||||
|
||||
for input in self.inputs.iter() {
|
||||
|
@ -122,10 +126,16 @@ impl ModelGen {
|
|||
log::debug!("Input file name: {:?}", file_name);
|
||||
log::debug!("Output file: {:?}", out_file);
|
||||
|
||||
log::info!("Converting {:?}", input);
|
||||
log::debug!("Input file name: {:?}", file_name);
|
||||
log::debug!("Output file: {:?}", out_file);
|
||||
|
||||
Self::generate_model(self.development, input, out_file);
|
||||
}
|
||||
|
||||
log::info!("Finished converting ONNX to Burn");
|
||||
|
||||
log::info!("Finished converting ONNX to Burn");
|
||||
}
|
||||
|
||||
/// Generate model source code and model state.
|
||||
|
@ -134,6 +144,10 @@ impl ModelGen {
|
|||
log::debug!("Development mode: {:?}", development);
|
||||
log::debug!("Output file: {:?}", out_file);
|
||||
|
||||
log::info!("Generating model from {:?}", input);
|
||||
log::debug!("Development mode: {:?}", development);
|
||||
log::debug!("Output file: {:?}", out_file);
|
||||
|
||||
let graph = parse_onnx(input.as_ref());
|
||||
|
||||
if development {
|
||||
|
@ -161,6 +175,8 @@ impl ModelGen {
|
|||
fs::write(out_file.with_extension("rs"), code_str).unwrap();
|
||||
|
||||
log::info!("Model generated");
|
||||
|
||||
log::info!("Model generated");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -186,64 +202,112 @@ impl ONNXGraph {
|
|||
NodeType::Relu => graph.register(Self::relu_conversion(node)),
|
||||
NodeType::Flatten => graph.register(Self::flatten_conversion(node)),
|
||||
NodeType::LogSoftmax => graph.register(Self::log_softmax_conversion(node)),
|
||||
NodeType::Constant => graph.register(Self::constant_conversion(node)),
|
||||
NodeType::Constant => graph.register(Self::constant_conversion::<PS>(node)),
|
||||
NodeType::Reshape => graph.register(Self::reshape_conversion(node)),
|
||||
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
|
||||
NodeType::Transpose => graph.register(Self::transpose_conversion(node)),
|
||||
NodeType::Concat => graph.register(Self::concat_conversion(node)),
|
||||
NodeType::Cast => graph.register(Self::cast_conversion(node)),
|
||||
_ => panic!("Unsupported node conversion {}", node.node_type),
|
||||
}
|
||||
}
|
||||
|
||||
// Get input and output names
|
||||
let input_names = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|input| input.name.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let output_names = self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|output| output.name.clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Register inputs and outputs with the graph
|
||||
graph.register_input_output(input_names, output_names);
|
||||
|
||||
graph
|
||||
}
|
||||
|
||||
fn constant_conversion(mut node: Node) -> ConstantNode {
|
||||
fn constant_conversion<PS: PrecisionSettings>(mut node: Node) -> ConstantNode<PS> {
|
||||
let output = node.outputs.get(0).unwrap();
|
||||
|
||||
let value = node.attrs.remove("value").unwrap();
|
||||
|
||||
let value = match value {
|
||||
AttributeValue::Float32(val) => ConstantValue::Float(val),
|
||||
AttributeValue::Int64(val) => ConstantValue::Int(val as i32),
|
||||
AttributeValue::Float32s(val) => ConstantValue::Float(val[0]),
|
||||
AttributeValue::Int64s(val) => ConstantValue::Int(val[0] as i32),
|
||||
_ => panic!("Unsupported constant node: {:?}", node),
|
||||
AttributeValue::Float32(val) => ConstantValue::Float32(val),
|
||||
AttributeValue::Int64(val) => ConstantValue::Int64(val),
|
||||
AttributeValue::Tensor(tensor) => {
|
||||
if tensor.dim == 0 {
|
||||
// Treat zero dim tensor as scalar value by extracting the first element
|
||||
// because PyTorch/ONNX uses zero dim tensor for scalar values
|
||||
match tensor.data.unwrap() {
|
||||
TensorData::Float32(val) => ConstantValue::Float32(val[0]),
|
||||
TensorData::Float64(val) => ConstantValue::Float64(val[0]),
|
||||
TensorData::Int32(val) => ConstantValue::Int32(val[0]),
|
||||
TensorData::Int64(val) => ConstantValue::Int64(val[0]),
|
||||
_ => panic!(
|
||||
"Unsupported zero dim constant tensor type: {:?} ",
|
||||
tensor.elem_type
|
||||
),
|
||||
}
|
||||
} else {
|
||||
let ds = match tensor.elem_type {
|
||||
ElementType::Float32 | ElementType::Float64 => TensorValue::Float(
|
||||
tensor.clone().into_data_serialize::<PS::FloatElem>(),
|
||||
),
|
||||
ElementType::Int32 | ElementType::Int64 => {
|
||||
TensorValue::Int(tensor.clone().into_data_serialize::<PS::IntElem>())
|
||||
}
|
||||
_ => panic!("Unsupported constant tensor type: {:?} ", tensor.elem_type),
|
||||
};
|
||||
|
||||
ConstantValue::<PS>::Tensor(
|
||||
TensorType::new(
|
||||
node.name.clone(),
|
||||
tensor.dim,
|
||||
tensor.elem_type.into(),
|
||||
tensor.shape,
|
||||
),
|
||||
ds,
|
||||
)
|
||||
}
|
||||
}
|
||||
_ => panic!("Unsupported constant value: {:?} ", value),
|
||||
};
|
||||
|
||||
ConstantNode::new(output.name.clone(), value)
|
||||
ConstantNode::new(node.name.clone(), value, output.to_type())
|
||||
}
|
||||
|
||||
fn add_conversion(node: Node) -> BinaryNode {
|
||||
// FIXME scalar vs tensor
|
||||
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let lhs = node.inputs.get(0).unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
BinaryNode::add(lhs, rhs, output)
|
||||
BinaryNode::add(lhs.clone(), rhs.clone(), output.clone())
|
||||
}
|
||||
|
||||
fn sub_conversion(node: Node) -> BinaryNode {
|
||||
// FIXME scalar vs tensor
|
||||
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let lhs = node.inputs.get(0).unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
BinaryNode::sub(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn mul_conversion(node: Node) -> BinaryNode {
|
||||
// FIXME scalar vs tensor
|
||||
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let lhs = node.inputs.get(0).unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
BinaryNode::mul(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn div_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let lhs = node.inputs.get(0).unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
BinaryNode::div(lhs, rhs, output)
|
||||
}
|
||||
|
@ -257,35 +321,42 @@ impl ONNXGraph {
|
|||
}
|
||||
|
||||
fn equal_conversion(node: Node) -> BinaryNode {
|
||||
let lhs = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let lhs = node.inputs.get(0).unwrap().to_type();
|
||||
let rhs = node.inputs.get(1).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
BinaryNode::equal(lhs, rhs, output)
|
||||
}
|
||||
|
||||
fn relu_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let input = node.inputs.get(0).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
UnaryNode::relu(input, output)
|
||||
}
|
||||
|
||||
fn flatten_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let input = node.inputs.get(0).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
let (start_dim, end_dim) = flatten_config(&node);
|
||||
|
||||
UnaryNode::flatten(input, output, start_dim, end_dim)
|
||||
}
|
||||
|
||||
fn transpose_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let input = node.inputs.get(0).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
UnaryNode::transpose(input, output)
|
||||
}
|
||||
|
||||
fn cast_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.get(0).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
UnaryNode::cast(input, output)
|
||||
}
|
||||
|
||||
fn reshape_conversion(mut node: Node) -> ReshapeNode {
|
||||
let input = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
|
@ -299,15 +370,15 @@ impl ONNXGraph {
|
|||
}
|
||||
|
||||
fn sigmoid_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let input = node.inputs.get(0).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
|
||||
UnaryNode::sigmoid(input, output)
|
||||
}
|
||||
|
||||
fn log_softmax_conversion(node: Node) -> UnaryNode {
|
||||
let input = node.inputs.get(0).unwrap().to_tensor_type();
|
||||
let output = node.outputs.get(0).unwrap().to_tensor_type();
|
||||
let input = node.inputs.get(0).unwrap().to_type();
|
||||
let output = node.outputs.get(0).unwrap().to_type();
|
||||
let dim = log_softmax_config(&node);
|
||||
|
||||
UnaryNode::log_softmax(input, output, dim)
|
||||
|
@ -419,8 +490,54 @@ impl Argument {
|
|||
pub fn to_tensor_type(&self) -> TensorType {
|
||||
match &self.ty {
|
||||
ArgType::Tensor(tensor) => TensorType::new_float(self.name.clone(), tensor.dim),
|
||||
_ => panic!("Can't transform to tensor."),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_type(&self) -> Type {
|
||||
match &self.ty {
|
||||
ArgType::Tensor(tensor) => {
|
||||
// Treat tensor with dim 0 as scalar
|
||||
if tensor.dim == 0 {
|
||||
// FIXME Convert to correct scalar type (@antimora 8/1/2023)
|
||||
// Currently it's not dangerous because we don't use specific scalar type
|
||||
Type::Scalar(ScalarType::new(self.name.clone(), ScalarKind::Float64))
|
||||
} else {
|
||||
Type::Tensor(TensorType::new_float(self.name.clone(), tensor.dim))
|
||||
}
|
||||
}
|
||||
|
||||
ArgType::Scalar(elem_type) => {
|
||||
Type::Scalar(ScalarType::new(self.name.clone(), elem_type.into()))
|
||||
}
|
||||
ArgType::Shape(_shape) => panic!("Can't transform shape to tensor."),
|
||||
ArgType::Constant => panic!("Can't transform constant to tensor."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&ElementType> for ScalarKind {
|
||||
fn from(elem_type: &ElementType) -> Self {
|
||||
match elem_type {
|
||||
ElementType::Float32 => ScalarKind::Float32,
|
||||
ElementType::Float64 => ScalarKind::Float64,
|
||||
ElementType::Int32 => ScalarKind::Int32,
|
||||
ElementType::Int64 => ScalarKind::Int64,
|
||||
ElementType::Bool => ScalarKind::Bool,
|
||||
ElementType::String => panic!("String tensor unsupported"),
|
||||
ElementType::Float16 => panic!("Float16 tensor unsupported"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ElementType> for TensorKind {
|
||||
fn from(elem_type: ElementType) -> Self {
|
||||
match elem_type {
|
||||
ElementType::Float32 => TensorKind::Float,
|
||||
ElementType::Float64 => TensorKind::Float,
|
||||
ElementType::Int32 => TensorKind::Int,
|
||||
ElementType::Int64 => TensorKind::Int,
|
||||
ElementType::Bool => TensorKind::Bool,
|
||||
_ => panic!("Unsupported tensor type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
// Generated by integration tests
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {}
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return)]
|
||||
pub fn forward(&self, input1: Tensor<B, 4>, input1: Tensor<B, 4>) -> Tensor<B, 4> {
|
||||
let concat1_out1 = burn::tensor::Tensor::cat(vec![input1.clone(), input1.clone()], 1);
|
||||
let concat2_out1 = burn::tensor::Tensor::cat(
|
||||
vec![
|
||||
input1.clone(),
|
||||
concat1_out1.clone(),
|
||||
concat1_out1.clone(),
|
||||
concat1_out1.clone(),
|
||||
concat1_out1,
|
||||
],
|
||||
1,
|
||||
);
|
||||
concat2_out1
|
||||
}
|
||||
}
|
Binary file not shown.
|
@ -1,36 +0,0 @@
|
|||
# used to generate model: burn-import/tests/data/conv2d/conv2d.onnx
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import onnx
|
||||
from onnxoptimizer import optimize
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.conv1 = nn.Conv2d(16, 36, (3, 5), groups = 2, stride=(2, 1), padding=(4, 2), dilation=(3, 1))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
return x
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
dummy_input = torch.randn(20, 16, 50, 100, device=device)
|
||||
torch.onnx.export(model, dummy_input, "conv2d.onnx",
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
# Apply the optimization pass to simplify the model
|
||||
onnx_model = onnx.load("conv2d.onnx")
|
||||
optimized_model = optimize(onnx_model)
|
||||
|
||||
# Save the optimized model
|
||||
onnx.save(optimized_model, "conv2d.onnx")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,17 +0,0 @@
|
|||
# Model1 test data files
|
||||
|
||||
This directory contains the test data for the model1 test. The test data is generated by running the
|
||||
following command:
|
||||
|
||||
```bash
|
||||
python3 model1.py
|
||||
cargo run model1.onnx ./
|
||||
```
|
||||
|
||||
The following files are generated:
|
||||
|
||||
- `model1.onnx`: The ONNX model
|
||||
- `model1.rs`: The generated Rust code for the model (the path in the comment needs to be fixed for
|
||||
the test)
|
||||
- `model1.json`: The data of the model
|
||||
- `model1.graph.txt`: The IR of the model
|
Binary file not shown.
|
@ -1,46 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import onnx
|
||||
from onnxoptimizer import optimize
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 8, 3)
|
||||
self.norm1 = nn.BatchNorm2d(8)
|
||||
self.fc1 = nn.Linear(8*6*6, 10)
|
||||
self.norm2 = nn.BatchNorm1d(10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.relu(x)
|
||||
x = self.norm1(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc1(x)
|
||||
x = self.norm2(x)
|
||||
output = F.log_softmax(x, dim=1)
|
||||
return output
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Export to onnx
|
||||
model = Model()
|
||||
model.eval()
|
||||
device = torch.device("cpu")
|
||||
dummy_input = torch.randn(1, 1, 8, 8, device=device)
|
||||
torch.onnx.export(model, dummy_input, "model1.onnx",
|
||||
verbose=False, opset_version=16)
|
||||
|
||||
# Apply the optimization pass to simplify the model
|
||||
onnx_model = onnx.load("model1.onnx")
|
||||
optimized_model = optimize(onnx_model)
|
||||
|
||||
# Save the optimized model
|
||||
onnx.save(optimized_model, "model1.onnx")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -1,62 +0,0 @@
|
|||
// Generated by integration tests
|
||||
use burn::nn::conv::Conv2d;
|
||||
use burn::nn::conv::Conv2dConfig;
|
||||
use burn::nn::BatchNorm;
|
||||
use burn::nn::BatchNormConfig;
|
||||
use burn::nn::Linear;
|
||||
use burn::nn::LinearConfig;
|
||||
use burn::nn::PaddingConfig2d;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
conv2d1: Conv2d<B>,
|
||||
batchnormalization1: BatchNorm<B, 2>,
|
||||
linear1: Linear<B>,
|
||||
batchnormalization2: BatchNorm<B, 0>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model<B> {
|
||||
pub fn new_with(record: ModelRecord<B>) -> Self {
|
||||
let conv2d1 = Conv2dConfig::new([1, 8], [3, 3])
|
||||
.with_stride([1, 1])
|
||||
.with_padding(PaddingConfig2d::Valid)
|
||||
.with_dilation([1, 1])
|
||||
.with_groups(1)
|
||||
.with_bias(true)
|
||||
.init_with(record.conv2d1);
|
||||
let batchnormalization1 = BatchNormConfig::new(8)
|
||||
.with_epsilon(0.000009999999747378752f64)
|
||||
.with_momentum(0.8999999761581421f64)
|
||||
.init_with(record.batchnormalization1);
|
||||
let linear1 = LinearConfig::new(288, 10)
|
||||
.with_bias(true)
|
||||
.init_with(record.linear1);
|
||||
let batchnormalization2 = BatchNormConfig::new(10)
|
||||
.with_epsilon(0.000009999999747378752f64)
|
||||
.with_momentum(0.8999999761581421f64)
|
||||
.init_with(record.batchnormalization2);
|
||||
Self {
|
||||
conv2d1,
|
||||
batchnormalization1,
|
||||
linear1,
|
||||
batchnormalization2,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return)]
|
||||
pub fn forward(&self, input1: Tensor<B, 4>) -> Tensor<B, 2> {
|
||||
let conv2d1_out1 = self.conv2d1.forward(input1);
|
||||
let relu1_out1 = burn::tensor::activation::relu(conv2d1_out1);
|
||||
let batchnormalization1_out1 = self.batchnormalization1.forward(relu1_out1);
|
||||
let flatten1_out1 = batchnormalization1_out1.flatten(1, 3);
|
||||
let linear1_out1 = self.linear1.forward(flatten1_out1);
|
||||
let batchnormalization2_out1 = self.batchnormalization2.forward(linear1_out1);
|
||||
let logsoftmax1_out1 = burn::tensor::activation::log_softmax(batchnormalization2_out1, 1);
|
||||
logsoftmax1_out1
|
||||
}
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
#[cfg(test)]
|
||||
#[cfg(feature = "onnx")]
|
||||
mod tests {
|
||||
use std::fs::read_to_string;
|
||||
use std::path::Path;
|
||||
|
||||
use burn::record::FullPrecisionSettings;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rstest::*;
|
||||
|
||||
fn code<P: AsRef<Path>>(onnx_path: P) -> String {
|
||||
let graph = burn_import::onnx::parse_onnx(onnx_path.as_ref());
|
||||
let graph = graph
|
||||
.into_burn::<FullPrecisionSettings>()
|
||||
.with_blank_space(true)
|
||||
.with_top_comment(Some("Generated by integration tests".into()));
|
||||
|
||||
burn_import::format_tokens(graph.codegen())
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case::mixed("model1")]
|
||||
#[case::conv2d("conv2d")]
|
||||
#[case::concat("concat")]
|
||||
// #[case::description_here("model2")] <- Add more models here
|
||||
fn test_codegen(#[case] model_name: &str) {
|
||||
let input_file = format!("tests/data/{model_name}/{model_name}.onnx");
|
||||
let source_file = format!("tests/data/{model_name}/{model_name}.rs");
|
||||
let source_expected: String =
|
||||
read_to_string(source_file).expect("Expected source file is missing");
|
||||
|
||||
let generated_code = code(input_file);
|
||||
|
||||
// Uncomment this to update the expected code
|
||||
// println!("Generated code:\n{}", generated_code);
|
||||
|
||||
assert_eq!(
|
||||
source_expected, generated_code,
|
||||
"Expected code is left, actual code is right"
|
||||
);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue