Contributor Book: Onnx to Burn Conversion (#1771)

* Start on steps (1 & 2) for implementing a new operator for onnx conversion

* Add more steps to process

* Touch up

* Fix dimension inferencing instructions

* Fix numbering and other small stuff

* Update w/links

* Add a warning about dimension changes

* Minor link touch-ups and wording

* Add a note on unary/binary operations
- [Design Goals](#design-goals)
- [Design Decisions](#design-decisions)
- [Adding New Operators](#adding-new-operators)
- [Implementing a New Operator](#implementing-a-new-operator)
- [Testing](#testing)
- [Resources](#resources)
To extend `burn-import` with support for new ONNX operators, follow these steps:

5. **Implement Missing Operators**: If you encounter an error stating that an operator is
   unsupported, [implement it](#implementing-a-new-operator). The `./out/my-model.graph.txt` should
   provide relevant information.
6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds
the Burn model in Rust code, and `my-model.json` includes the model data.
Further details can be found in the
[onnx-tests README](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-import/onnx-tests/README.md).
## Implementing a New Operator
To extend the capabilities of the Burn library by supporting new operations imported from ONNX
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths
are relative to `burn/crates/burn-import/`.
### Step 1: Visibility
To make a new operation accessible to the rest of the Burn project, you need to declare the module
within the
[`mod.rs` file](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/mod.rs#L24)
located in the `src/burn/node/` directory.
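
For example, declaring the `Squeeze` module looks like this (abridged; the surrounding file simply
lists one module per operation):

```rust
// src/burn/node/mod.rs (abridged): declare the new module alongside the others.
pub(crate) mod squeeze;
```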
### Step 2: Node Implementation
Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory.
This file will define the structure and functionality of your new operation. By convention, the
necessary information for carrying out an operation is encapsulated within a struct named
`<operation>Node`. For the `Squeeze` operation, we defined a
[struct called `SqueezeNode`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/squeeze.rs#L8)
that holds necessary information about the input tensor, output tensor, and axes for the operation.
**If implementing a unary or binary operation, please see the note below.**
The core of integrating a new operation involves implementing the `NodeCodegen` trait for your node.
This trait defines how the node generates code during the graph compilation process. The
implementation must provide methods to define input and output types, to generate the forward pass
code, and to encapsulate the node into the more general `Node` structure. Specifically:
- `output_types` and `input_types` return the tensor (or element) types for the output and inputs of
the node, respectively.
- `forward` generates the Rust code that performs the operation during the execution phase. The
  `quote!` macro is used to generate the Rust code; make sure the emitted code is syntactically
  correct and calls Burn's API properly.
- `into_node` wraps the specific node in a general `Node` type, facilitating its inclusion in the
broader Burn graph structure.
This file is also where `test_codegen_nodes()` belongs, to verify that the node generates the Rust
code the Burn library expects.
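
To make this concrete, below is an abridged sketch modeled on the `Squeeze` implementation (the
`new` derive comes from the `derive-new` crate); consult the linked `squeeze.rs` for the exact
imports and signatures.

```rust
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, ToTokens, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

// Holds everything needed to generate code for one Squeeze node.
#[derive(Debug, Clone, new)]
pub struct SqueezeNode {
    pub input: TensorType,
    pub output: TensorType,
    pub axes: Vec<i64>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    // Emits the Burn code executed during the forward pass.
    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let axis = self.axes.first().unwrap().to_tokens();

        quote! {
            let #output = #input.squeeze(#axis);
        }
    }

    // Wraps this node in the general `Node` enum (see step 6).
    fn into_node(self) -> Node<PS> {
        Node::Squeeze(self)
    }
}
```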
**For unary and binary operations:** Most of the `NodeCodegen` implementation already lives in
[`binary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/binary.rs#L9)
and
[`unary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/unary.rs#L13),
so each new operation only has to supply a function that maps the input token stream(s) to the
forward expression.
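
As a rough illustration (simplified; check `unary.rs` for the real constructor signature and the
available `UnaryNodeKind` variants), a new unary operation reduces to a constructor that supplies
this mapping:

```rust
// Simplified sketch of a unary op constructor: UnaryNode stores the closure
// and implements NodeCodegen once for all unary operations.
pub(crate) fn erf(input: Type, output: Type) -> Self {
    let function = move |input| quote! { #input.erf() };
    Self::new(input, output, UnaryNodeKind::Erf, Rc::new(function))
}
```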
### Step 3: Registering New Operations
[Register the `NodeType::<operation>`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/to_burn.rs#L293)
and
[create an `<operation>_conversion(node: Node)` function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/to_burn.rs#L831),
both in `src/onnx/to_burn.rs`.
**Registering new operations in the ONNX -> Burn Conversion**
To integrate new operations from an ONNX graph into the Burn framework, each operation must be
registered within the ONNX graph conversion process. This is done in the `src/onnx/to_burn.rs` file,
where the conversion from ONNX nodes to Burn nodes is orchestrated.
In the `into_burn()` method of the `OnnxGraph` struct, operations are matched with their
corresponding conversion functions. This method iterates over each node in the ONNX graph and,
depending on the node type, calls a specific conversion function that translates the ONNX node into
a corresponding Burn node.
```rust
impl OnnxGraph {
    pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
        let mut graph = BurnGraph::<PS>::default();
        let mut unsupported_ops = vec![];
        for node in self.nodes {
            match node.node_type {
                NodeType::Add => graph.register(Self::add_conversion(node)),
                // Other operations...
                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                // Add new operations here
                node_type => unsupported_ops.push(node_type),
            }
        }
        if !unsupported_ops.is_empty() {
            panic!("Unsupported ops: {unsupported_ops:?}");
        }
        // ... (abridged: register the graph inputs/outputs before returning)
        graph
    }
}
```
Here, the `NodeType::Squeeze` matches the ONNX node type with the `squeeze_conversion()` function
that you define to handle the specific attributes and settings of a Squeeze operation.
**Define the Conversion Function**
Each operation conversion function extracts the necessary information from the ONNX node and
constructs a corresponding Burn node. The structure of these functions generally includes the
following steps (a sketch follows the list):
1. Extracting input and output tensors from the node.
2. Retrieving and processing operation-specific configurations.
3. Calling `<operation>_config()` to parse ONNX node configurations.
4. Creating an instance of the appropriate Burn node
([defined in step 2](#step-2-node-implementation)) using this information.
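
Here is what this looks like for `Squeeze` (a simplified sketch; compare with the linked
`to_burn.rs`):

```rust
// Sketch of squeeze_conversion: extract the tensor types, parse the
// op-specific configuration, and construct the Burn node from step 2.
fn squeeze_conversion(node: Node) -> SqueezeNode {
    let input = node.inputs.first().unwrap().to_tensor_type();
    let output = node.outputs.first().unwrap().to_tensor_type();
    let axes = squeeze_config(&node);

    SqueezeNode::new(input, output, axes)
}
```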
### Step 4: Create a Config Function
[Create an `<operation>_config(curr: &Node)`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/op_configuration.rs#L975)
in `src/onnx/op_configuration.rs`.
The `squeeze_conversion()` function in `src/onnx/to_burn.rs` from the previous step calls the
`squeeze_config()` function in `src/onnx/op_configuration.rs` in order to parse the ONNX node's
attributes and extract the parameters specific to the Squeeze operation: in this case, the axes
along which the squeeze is performed.
> 📘 Info: Understanding Generic `config` Patterns
>
> The `<op>_config()` functions follow a similar pattern:
>
> 1. Extract tensor or scalar types for inputs and outputs.
> 2. Validate the input structure and types for each node, ensuring they conform to expected formats
> (panicking if not).
> 3. Parse and convert configurations or parameters specific to each operation.
> 4. Create and return a node specific to the operation, initialized with extracted values and
> configurations.
>
> For example, config functions handle operation-specific settings such as the kernel size for
> pooling, or the different tensor and scalar types accepted by power operations.
These functions translate the varied and flexible structure of ONNX nodes into the more structured
and type-safe environment of Rust and the Burn framework. Compliance with the ONNX spec is handled
here.
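
A simplified sketch of `squeeze_config()` following this pattern (the real function in
`op_configuration.rs` handles more edge cases):

```rust
// Simplified sketch of squeeze_config: validate that the input is a tensor,
// then pull the `axes` attribute off the ONNX node.
pub fn squeeze_config(curr: &Node) -> Vec<i64> {
    let axes = curr
        .attrs
        .iter()
        .filter(|(key, _)| key.as_str() == "axes")
        .map(|(_, value)| value.clone().into_i64s())
        .next()
        .unwrap_or_default();

    // Spec compliance: Squeeze is only defined for tensor inputs.
    match curr.inputs.first().unwrap().ty {
        ArgType::Tensor(_) => (),
        _ => panic!("Squeeze: only tensor input is valid"),
    }

    axes
}
```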
### Step 5: Dimension Inference
If needed,
[create a dimension inference function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/dim_inference.rs#L271),
called `<operation>_update_output(node: &mut Node)` in `src/onnx/dim_inference.rs`. If dimensions
remain unchanged, use the `same_as_input()` function, for example
`NodeType::AveragePool1d => same_as_input(node)`. Match the `NodeType` to the function in the
`dim_inference()` match block.
Dimension inference is an important step in the conversion process where Burn determines the
dimensions of each output tensor based on the operation.
[The `dim_inference()`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/dim_inference.rs#L14)
function is responsible for determining the dimensions of the output tensors for each node in the
graph. It does this by:
1. **Matching the Node Type**: The function uses a `match` statement on the `node_type` of each node
to apply the correct dimension inference logic depending on the operation.
2. **Applying Operation-Specific Logic**: For each operation, a specific inference function is
   called that encapsulates the rules for how output dimensions should be derived from the inputs.
For the Squeeze operation, dimension inference is handled by the `squeeze_update_output()`
function. The rule is currently simple: the output tensor's rank is the input tensor's rank minus
one.
> 📘 Info: How `squeeze_update_output()` Works
>
> 1. Validation of axes input: We first check if the second input of the node contains a list of
> integers, which represent the axes along which the squeeze operation is applied. The function
> also validates that only one axis is specified for squeezing, ensuring that the operation's
> requirements within Burn are followed.
> 2. Extracting input dimensions: The input tensor's dimension is extracted from the first input.
> 3. Configuring output dimensions: The output tensor's dimensions are then set to be one less than
>    the input tensor's dimensions, reflecting the reduction in rank caused by the squeeze
>    operation.
> 4. The function includes several checks that throw errors (panics) if the inputs do not meet the
> expected types or configurations, such as when the axes are not provided as an integer list or
> if the input type is not a tensor.
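
Condensed into a sketch (assuming, as described above, that the axes arrive as the node's second
input; the type and field names follow this crate's ONNX IR and may differ slightly from the linked
source):

```rust
// Condensed sketch of squeeze_update_output: validate the axes, then set the
// output rank to the input rank minus one. `Data`, `ArgType`, and `TensorType`
// come from this crate's ONNX IR module.
fn squeeze_update_output(node: &mut Node) {
    // 1. Validate the axes input: exactly one integer axis must be given.
    let axes = match &node.inputs[1].value {
        Some(Data::Int64s(axes)) => axes.clone(),
        _ => panic!("Squeeze: axes must be a list of integers"),
    };
    assert_eq!(axes.len(), 1, "Squeeze: only one axis is supported");

    // 2. Extract the input tensor's rank and element type.
    let (dim, elem_type) = match &node.inputs[0].ty {
        ArgType::Tensor(tensor) => (tensor.dim, tensor.elem_type.clone()),
        _ => panic!("Squeeze: input must be a tensor"),
    };

    // 3. The output is a tensor with one fewer dimension than the input.
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        elem_type,
        dim: dim - 1,
        shape: None,
    });
}
```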
By invoking this function within the `dim_inference()` match block, the output dimensions of each
node are updated before the graph is finalized. This ensures that all subsequent operations within
the graph can rely on correct tensor sizes, which is critical for both compiling the graph and for
runtime execution efficiency.
If something is amiss after this step (i.e. you are seeing strange panics) and the dimensions of
your output tensor differ from the dimensions of your input, see the warning at the very end.
### Step 6: Integrate into the Graph Building Process
When a new node type is introduced, it must be added to the
[`Node<PS: PrecisionSettings>` enum](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/base.rs#L77)
and
[`match_all!` macro](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/base.rs#L104)
in `src/burn/node/base.rs`.
The `Node` enum abstracts over the different types of operations (nodes) within a network graph.
Each variant of the enum corresponds to a specific type of operation and encapsulates the
operation-specific data structure (like `SqueezeNode`) that was
[defined in step 2](#step-2-node-implementation).
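
Both additions are mechanical (abridged):

```rust
// src/burn/node/base.rs (abridged): add a variant for the new node type...
pub enum Node<PS: PrecisionSettings> {
    // ...existing variants...
    Squeeze(SqueezeNode),
}

// ...and add a matching arm to the match_all! macro so the NodeCodegen
// methods can be dispatched to it.
macro_rules! match_all {
    ($self:expr, $func:expr) => {{
        match $self {
            // ...existing arms...
            Node::Squeeze(node) => $func(node),
        }
    }};
}
```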
### Step 7: Add Newly Supported Op!
As a reward, add a check for your newly supported op to
[SUPPORTED-ONNX-OPS.md](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/SUPPORTED-ONNX-OPS.md?plain=1#L1)!
### Misc
> 🚧 **Warning**: Dimension Changes
>
> If your operation changes the dimensions of the input tensor, you may need to modify the
> [`LIFT_CONSTANTS_FOR_NODE_TYPES` enum](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/from_onnx.rs#L20)
> in `src/onnx/from_onnx.rs` by adding the `NodeType` of your operation to it.
## Testing
- Unit tests for the Burn graph to Rust source code conversion are mandatory.