# Contributor Book: ONNX to Burn Conversion (#1771)

- Start on steps (1 & 2) for implementing a new operator for ONNX conversion
- Add more steps to the process
- Touch up
- Fix dimension inference instructions
- Fix numbering and other small stuff
- Update with links
- Add a warning about dimension changes
- Minor link touch-ups and wording
- Add a note on unary/binary operations

- [Design Goals](#design-goals)
- [Design Decisions](#design-decisions)
- [Adding New Operators](#adding-new-operators)
- [Implementing a New Operator](#implementing-a-new-operator)
- [Testing](#testing)
- [Resources](#resources)

To extend `burn-import` with support for new ONNX operators, follow these steps:

5. **Implement Missing Operators**: If you encounter an error stating that an operator is
   unsupported, [implement it](#implementing-a-new-operator). The `./out/my-model.graph.txt` should
   provide relevant information.

6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds
   the Burn model in Rust code, and `my-model.json` includes the model data.

Further details can be found in the
[onnx-tests README](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-import/onnx-tests/README.md).

## Implementing a New Operator

To extend the capabilities of the Burn library by supporting new operations imported from ONNX
graphs, developers must go through a few systematic steps. Here, we detail the process, using the
implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths
are relative to `burn/crates/burn-import/`.

### Step 1: Visibility

To make a new operation accessible to the rest of the Burn project, you need to declare the module
within the
[`mod.rs` file](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/mod.rs#L24)
located in the `src/burn/node/` directory.

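For example, the declaration for a `squeeze` module is a single line; this is a sketch, with the visibility matching the surrounding entries in the linked `mod.rs`:

```rust
// In src/burn/node/mod.rs: declare the new module alongside the existing ones.
pub(crate) mod squeeze;
```
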
### Step 2: Node Implementation

Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory.
This file will define the structure and functionality of your new operation. By convention, the
necessary information for carrying out an operation is encapsulated within a struct named
`<operation>Node`. For the `Squeeze` operation, we defined a
[struct called `SqueezeNode`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/squeeze.rs#L8)
that holds the necessary information about the input tensor, output tensor, and axes for the
operation. **If implementing a unary or binary operation, please see the note below.**

The core of integrating a new operation involves implementing the `NodeCodegen` trait for your node.
This trait defines how the node generates code during the graph compilation process. The
implementation must provide methods to define input and output types, to generate the forward pass
code, and to encapsulate the node into the more general `Node` structure. Specifically:

- `output_types` and `input_types` return the tensor (or element) types for the output and inputs of
  the node, respectively.
- `forward` generates the Rust code that performs the operation during the execution phase. The
  `quote!` macro is used to generate the Rust code; ensure that the generated code is syntactically
  correct and uses Burn's API.
- `into_node` wraps the specific node in a general `Node` type, facilitating its inclusion in the
  broader Burn graph structure.

This file is also where you would put `test_codegen_nodes()` to make sure that the generated code
works within the Burn library.

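As an illustration, here is a condensed sketch of the `Squeeze` implementation, loosely paraphrased from the linked `squeeze.rs`; helper names such as `tensor_use_owned()` and `to_tokens()` come from the existing node implementations, and exact signatures may differ:

```rust
#[derive(Debug, Clone, new)]
pub struct SqueezeNode {
    pub input: TensorType,
    pub output: TensorType,
    pub axes: Vec<i64>,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }

    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        // Bind the input tensor and the output variable name, then emit
        // the Burn call that performs the squeeze.
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let axis = self.axes.first().unwrap().to_tokens();

        quote! {
            let #output = #input.squeeze(#axis);
        }
    }

    fn into_node(self) -> Node<PS> {
        Node::Squeeze(self)
    }
}
```
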
**For unary and binary operations:** most of the `NodeCodegen` implementation already lives in
[`binary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/binary.rs#L9)
and
[`unary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/unary.rs#L13),
so each new operation only has to define a method that applies the function to the input token
stream(s).

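As a sketch of this pattern (adapted from the linked `unary.rs`; the exact constructor shape may differ), a unary operation mostly just supplies a closure that wraps the input token stream in the corresponding Burn call:

```rust
// Sketch: a unary op only defines how the input expression is wrapped.
impl UnaryNode {
    pub(crate) fn erf(input: Type, output: Type) -> Self {
        Self::new(
            input,
            output,
            UnaryNodeKind::Erf,
            Rc::new(|input| quote! { #input.erf() }),
        )
    }
}
```
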
### Step 3: Registering New Operations

[Register the `NodeType::<operation>`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/to_burn.rs#L293)
and
[create an `<operation>_conversion(node: Node)` function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/to_burn.rs#L831),
both in `src/onnx/to_burn.rs`.

**Registering new operations in the ONNX -> Burn Conversion**

To integrate new operations from an ONNX graph into the Burn framework, each operation must be
registered within the ONNX graph conversion process. This is done in the `src/onnx/to_burn.rs` file,
where the conversion from ONNX nodes to Burn nodes is orchestrated.

In the `into_burn()` method of the `OnnxGraph` struct, operations are matched with their
corresponding conversion functions. This method iterates over each node in the ONNX graph and,
depending on the node type, calls a specific conversion function that translates the ONNX node into
a corresponding Burn node.

```rust
impl OnnxGraph {
    pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
        let mut graph = BurnGraph::<PS>::default();
        let mut unsupported_ops = vec![];

        for node in self.nodes {
            match node.node_type {
                NodeType::Add => graph.register(Self::add_conversion(node)),
                // Other operations...
                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
                // Add new operations here
                // Unknown node types are collected and reported together.
                node_type => unsupported_ops.push(node_type),
            }
        }

        if !unsupported_ops.is_empty() {
            panic!("Unsupported ops: {:?}", unsupported_ops);
        }

        // ... (input/output registration elided)
        graph
    }
}
```

Here, the `NodeType::Squeeze` arm matches the ONNX node type to the `squeeze_conversion()` function
that you define to handle the specific attributes and settings of a Squeeze operation.

**Define the Conversion Function**

Each operation conversion function extracts the necessary information from the ONNX node and
constructs a corresponding Burn node. The structure of these functions generally includes:

1. Extracting input and output tensors from the node.
2. Retrieving and processing operation-specific configurations.
3. Calling `<operation>_config()` to parse ONNX node configurations.
4. Creating an instance of the appropriate Burn node
   ([defined in step 2](#step-2-node-implementation)) using this information.

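A sketch of what this looks like for Squeeze, paraphrased from the linked `to_burn.rs` (helper names such as `to_tensor_type()` are assumptions):

```rust
fn squeeze_conversion(node: Node) -> SqueezeNode {
    // 1. Extract the input and output tensors from the node.
    let input = node.inputs.first().unwrap().to_tensor_type();
    let output = node.outputs.first().unwrap().to_tensor_type();
    // 2./3. Parse the operation-specific configuration (here, the axes).
    let axes = squeeze_config(&node);

    // 4. Build the Burn node defined in step 2.
    SqueezeNode::new(input, output, axes)
}
```
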
### Step 4: Create a Config Function

[Create an `<operation>_config(curr: &Node)`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/op_configuration.rs#L975)
in `src/onnx/op_configuration.rs`.

The `squeeze_conversion()` function in `src/onnx/to_burn.rs` from the previous step calls the
`squeeze_config()` function in `src/onnx/op_configuration.rs` in order to parse the ONNX node's
attributes and extract the parameters specific to the Squeeze operation: in this case, the axes
along which the squeeze is performed.

> 📘 Info: Understanding Generic `config` Patterns
>
> The `<op>_config()` functions follow a similar pattern:
>
> 1. Extract tensor or scalar types for inputs and outputs.
> 2. Validate the input structure and types for each node, ensuring they conform to expected formats
>    (panicking if not).
> 3. Parse and convert configurations or parameters specific to each operation.
> 4. Create and return a node specific to the operation, initialized with extracted values and
>    configurations.
>
> For example, config functions handle specific settings like the kernel size for pooling, or the
> different tensor and scalar types for power operations.

These functions translate the more varied and flexible structure of ONNX nodes into the more
structured and type-safe environment of Rust and the Burn framework. Spec compliance is dealt with
here.

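A simplified sketch of `squeeze_config()` following this pattern (the real function in the linked `op_configuration.rs` performs more validation; the attribute-handling helpers shown are assumptions):

```rust
pub fn squeeze_config(curr: &Node) -> Vec<i64> {
    // Parse the `axes` attribute off the ONNX node, defaulting to empty.
    let axes = curr
        .attrs
        .iter()
        .filter_map(|(key, value)| match key.as_str() {
            "axes" => Some(value.clone().into_i64s()),
            _ => None,
        })
        .next()
        .unwrap_or_default();

    // Validate the input: only a tensor input is valid here.
    match &curr.inputs.first().unwrap().ty {
        ArgType::Tensor(_) => (),
        _ => panic!("Squeeze: only tensor input is valid"),
    }

    axes
}
```
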
### Step 5: Dimension Inference

If needed,
[create a dimension inference function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/dim_inference.rs#L271)
called `<operation>_update_output(node: &mut Node)` in `src/onnx/dim_inference.rs`. If dimensions
remain unchanged, use the `same_as_input()` function, for example
`NodeType::AveragePool1d => same_as_input(node)`. Match the `NodeType` to the function in the
`dim_inference()` match block.

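The Squeeze registration in that match block looks roughly like this (a fragment; the surrounding arms and the catch-all are illustrative):

```rust
// In dim_inference(): route each NodeType to its inference function.
match node.node_type {
    // ...
    NodeType::AveragePool1d => same_as_input(node),
    NodeType::Squeeze => squeeze_update_output(node),
    // ...
    node_type => panic!("dimension inference missing for {:?}", node_type),
}
```
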
Dimension inference is an important step in the conversion process where Burn determines the
dimensions of each output tensor based on the operation. The
[`dim_inference()`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/dim_inference.rs#L14)
function performs this for every node in the graph. It does this by:

1. **Matching the Node Type**: The function uses a `match` statement on the `node_type` of each node
   to apply the correct dimension inference logic depending on the operation.
2. **Applying Operation-Specific Logic**: For each operation, a specific inference function is
   called that encapsulates the rules for how output dimensions should be derived from the inputs.

For the Squeeze operation, dimension inference is handled by the `squeeze_update_output()`
function. Its logic is currently straightforward: the output tensor's rank is the input tensor's
rank minus one.

> 📘 Info: How `squeeze_update_output()` Works
>
> 1. Validation of axes input: We first check whether the second input of the node contains a list
>    of integers, which represent the axes along which the squeeze operation is applied. The
>    function also validates that only one axis is specified for squeezing, ensuring that the
>    operation's requirements within Burn are followed.
> 2. Extracting input dimensions: The input tensor's dimension is extracted from the first input.
> 3. Configuring output dimensions: The output tensor's dimensions are then set to one less than the
>    input tensor's dimensions, reflecting the reduction in rank caused by the squeeze operation.
> 4. Error handling: The function includes several checks that panic if the inputs do not meet the
>    expected types or configurations, such as when the axes are not provided as an integer list or
>    when the input type is not a tensor.

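A simplified sketch of `squeeze_update_output()` following the four points above, paraphrased from the linked `dim_inference.rs` (field names such as `dim`, `value`, and `ty` are assumptions):

```rust
fn squeeze_update_output(node: &mut Node) {
    // 1. Validate the axes input: an integer list with exactly one axis.
    let axes = match &node.inputs[1].value {
        Some(Data::Int64s(axes)) => axes.clone(),
        _ => panic!("Squeeze: axes must be an integer list"),
    };
    assert_eq!(axes.len(), 1, "Squeeze: only one axis is supported");

    // 2. Extract the input tensor's rank from the first input.
    let input_dim = match &node.inputs[0].ty {
        ArgType::Tensor(tensor) => tensor.dim,
        _ => panic!("Squeeze: input must be a tensor"),
    };

    // 3. The output rank is one less than the input rank.
    let output = match &node.outputs[0].ty {
        ArgType::Tensor(tensor) => tensor.clone(),
        _ => panic!("Squeeze: output must be a tensor"),
    };
    node.outputs[0].ty = ArgType::Tensor(TensorType {
        dim: input_dim - 1,
        ..output
    });
    // 4. Error handling is covered by the panics above.
}
```
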
By invoking this function within the `dim_inference()` match block, the output dimensions of each
node are updated before the graph is finalized. This ensures that all subsequent operations within
the graph can rely on correct tensor sizes, which is critical both for compiling the graph and for
runtime execution efficiency.

If something is amiss after doing this step (i.e. unexpected panics occur) and the dimensions of
your output tensor differ from the dimensions of your input, see the warning at the very end.

### Step 6: Integrate into the Graph Building Process

When a new node type is introduced, it must be added to the
[`Node<PS: PrecisionSettings>` enum](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/base.rs#L77)
and
[`match_all!` macro](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/base.rs#L104)
in `src/burn/node/base.rs`.

The `Node` enum abstracts over different types of operations (nodes) within a network graph. Each
variant of the enum corresponds to a specific type of operation and encapsulates the
operation-specific data structure (like `SqueezeNode`) that was
[defined in step 2](#step-2-node-implementation).

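A sketch of the two additions in `src/burn/node/base.rs` (surrounding variants and macro arms elided):

```rust
// 1. Add a variant for the new node type to the Node enum.
pub enum Node<PS: PrecisionSettings> {
    // ...existing variants...
    Squeeze(SqueezeNode),
}

// 2. Add the corresponding arm inside the match_all! macro so generic
//    dispatch reaches the new variant, e.g.:
//        Node::Squeeze(node) => $func(node),
```
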
### Step 7: Add Newly Supported Op!

As a reward, add a checkmark for your new op to
[SUPPORTED-ONNX-OPS.md](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/SUPPORTED-ONNX-OPS.md?plain=1#L1)!

### Misc

> 🚧 **Warning**: Dimension Changes
>
> If your operation changes the dimensions of the input tensor, you may need to modify the
> [`LIFT_CONSTANTS_FOR_NODE_TYPES` constant](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/from_onnx.rs#L20)
> in `src/onnx/from_onnx.rs` by adding the `NodeType` of your operation to it.

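An illustrative sketch, assuming the list is a const array of `NodeType`s as in the linked `from_onnx.rs` (the actual contents differ):

```rust
// In src/onnx/from_onnx.rs: include your operation so its constant
// inputs are lifted before conversion. Entries shown are illustrative.
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 2] = [
    NodeType::Unsqueeze,
    NodeType::Squeeze, // add your operation's NodeType here
];
```
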
## Testing

- Unit tests for the Burn graph to Rust source code conversion are mandatory.