diff --git a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md
index 663498919..d9d31282c 100644
--- a/contributor-book/src/guides/onnx-to-burn-conversion-tool.md
+++ b/contributor-book/src/guides/onnx-to-burn-conversion-tool.md
@@ -16,6 +16,7 @@ For an introduction to ONNX import in Burn, see
 - [Design Goals](#design-goals)
 - [Design Decisions](#design-decisions)
 - [Adding New Operators](#adding-new-operators)
+  - [Implementing a New Operator](#implementing-a-new-operator)
 - [Testing](#testing)
 - [Resources](#resources)
 
@@ -63,7 +64,8 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
    ```
 
 5. **Implement Missing Operators**: If you encounter an error stating that an operator is
-   unsupported, implement it. The `./out/my-model.graph.txt` should provide relevant information.
+   unsupported, [implement it](#implementing-a-new-operator). The `./out/my-model.graph.txt` should
+   provide relevant information.
 
 6. **Inspect Generated Files**: The `my-model.graph.txt` contains IR details, `my-model.rs` holds
    the Burn model in Rust code, and `my-model.json` includes the model data.
@@ -73,6 +75,201 @@ To extend `burn-import` with support for new ONNX operators, follow these steps:
 Further details can be found in the
 [onnx-tests README](https://github.com/tracel-ai/burn/blob/6d96e8d8086d2309c425f2c8a43a8246f8c454d2/crates/burn-import/onnx-tests/README.md).
 
+## Implementing a New Operator
+
+To extend the capabilities of the Burn library by supporting new operations imported from ONNX
+graphs, developers must go through a few systematic steps. Here, we detail the process, using the
+implementation of the `Squeeze` operation to illustrate points as needed. All file/directory paths
+are relative to `burn/crates/burn-import/`.
+
+### Step 1: Visibility
+
+To make a new operation accessible to the rest of the Burn project, declare its module within the
+[`mod.rs` file](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/mod.rs#L24)
+located in the `src/burn/node/` directory.
+
+### Step 2: Node Implementation
+
+Create a new file named `<operation_name>.rs` in the `src/burn/node/` directory. This file defines
+the structure and functionality of your new operation. By convention, the necessary information
+for carrying out an operation is encapsulated within a struct named `<operation_name>Node`. For
+the `Squeeze` operation, we defined a
+[struct called `SqueezeNode`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/squeeze.rs#L8)
+that holds the necessary information about the input tensor, output tensor, and axes for the
+operation. **If implementing a unary or binary operation, please see the note below.**
+
+The core of integrating a new operation is implementing the `NodeCodegen` trait for your node.
+This trait defines how the node generates code during the graph compilation process. The
+implementation must provide methods to define input and output types, to generate the forward
+pass code, and to encapsulate the node in the more general `Node` structure. Specifically:
+
+- `output_types` and `input_types` return the tensor (or element) types for the output and inputs
+  of the node, respectively.
+- `forward` generates the Rust code that performs the operation during the execution phase. The
+  `quote!` macro is used to build the token stream; ensure that the generated code is
+  syntactically correct and uses Burn's API.
+- `into_node` wraps the specific node in the more general `Node` type, facilitating its inclusion
+  in the broader Burn graph structure.
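+
+To make these pieces concrete, here is a condensed sketch of `SqueezeNode` and its `NodeCodegen`
+implementation. It is simplified for illustration, and helper names (such as `to_tokens()` and
+`tensor_use_owned()`) and exact signatures may differ from the current code, so refer to the
+linked `squeeze.rs` for the authoritative version:
+
+```rust
+/// Holds everything needed to generate code for a squeeze.
+#[derive(Debug, Clone, new)]
+pub struct SqueezeNode {
+    pub input: TensorType,
+    pub output: TensorType,
+    pub axes: Vec<i64>,
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for SqueezeNode {
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+
+    fn input_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.input.clone())]
+    }
+
+    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
+        let input = scope.tensor_use_owned(&self.input, node_position);
+        let output = &self.output.name;
+        // Lift the target rank and the squeeze axis into the token stream.
+        let output_rank = self.output.dim.to_tokens();
+        let axis = (*self.axes.first().unwrap() as usize).to_tokens();
+
+        quote! {
+            let #output = #input.squeeze::<#output_rank>(#axis);
+        }
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::Squeeze(self)
+    }
+}
+```
+
+Note that `forward` does not execute the squeeze itself; it only emits the tokens that will
+perform it in the generated model source.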
+
+This file is also where you would put codegen tests such as `test_codegen_nodes()`, to make sure
+that the generated code works within the Burn library.
+
+**For unary and binary operations:** `NodeCodegen` is already implemented generically in
+[`binary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/binary.rs#L9)
+and
+[`unary.rs`](https://github.com/tracel-ai/burn/blob/76fe0ed881b3965782f78896433f8bb5e2f13a1b/crates/burn-import/src/burn/node/unary.rs#L13),
+so each new operation only has to define a function that produces the token stream for the
+operation on its input(s).
+
+### Step 3: Registering New Operations
+
+[Register the `NodeType::<OperationName>`](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/to_burn.rs#L293)
+and
+[create an `<operation_name>_conversion(node: Node)` function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/to_burn.rs#L831),
+both in `src/onnx/to_burn.rs`.
+
+**Registering new operations in the ONNX -> Burn Conversion**
+
+To integrate new operations from an ONNX graph into the Burn framework, each operation must be
+registered within the ONNX graph conversion process. This is done in the `src/onnx/to_burn.rs`
+file, where the conversion from ONNX nodes to Burn nodes is orchestrated.
+
+In the `into_burn()` method of the `OnnxGraph` struct, operations are matched with their
+corresponding conversion functions. This method iterates over each node in the ONNX graph and,
+depending on the node type, calls a specific conversion function that translates the ONNX node
+into a corresponding Burn node.
+
+```rust
+impl OnnxGraph {
+    pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
+        let mut graph = BurnGraph::<PS>::default();
+        let mut unsupported_ops = vec![];
+
+        for node in self.nodes {
+            match node.node_type {
+                NodeType::Add => graph.register(Self::add_conversion(node)),
+                // Other operations...
+                NodeType::Squeeze => graph.register(Self::squeeze_conversion(node)),
+                // Add new operations here
+                node_type => unsupported_ops.push(node_type),
+            }
+        }
+        // ...
+    }
+}
+```
+
+Here, `NodeType::Squeeze` matches the ONNX node type with the `squeeze_conversion()` function
+that you define to handle the specific attributes and settings of a Squeeze operation.
+
+**Define the Conversion Function**
+
+Each operation conversion function extracts the necessary information from the ONNX node and
+constructs a corresponding Burn node. The structure of these functions generally includes:
+
+1. Extracting input and output tensors from the node.
+2. Retrieving and processing operation-specific configurations.
+3. Calling `<operation_name>_config()` to parse the ONNX node's configuration.
+4. Creating an instance of the appropriate Burn node
+   ([defined in step 2](#step-2-node-implementation)) using this information.
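+
+As a rough sketch, the conversion function and the config function it calls (the subject of step
+4 below) might look as follows for Squeeze. This is simplified for illustration, and the helper
+names (`to_tensor_type()`, `into_i64s()`) are assumptions based on the surrounding code base;
+consult `to_burn.rs` and `op_configuration.rs` for the real implementations:
+
+```rust
+// In src/onnx/to_burn.rs
+fn squeeze_conversion(node: Node) -> SqueezeNode {
+    // 1. Extract the input and output tensor types from the ONNX node.
+    let input = node.inputs.first().unwrap().to_tensor_type();
+    let output = node.outputs.first().unwrap().to_tensor_type();
+    // 2./3. Parse the operation-specific configuration (the squeeze axes).
+    let axes = squeeze_config(&node);
+
+    // 4. Build the Burn node defined in step 2.
+    SqueezeNode::new(input, output, axes)
+}
+
+// In src/onnx/op_configuration.rs
+fn squeeze_config(curr: &Node) -> Vec<i64> {
+    // Pull the `axes` attribute off the ONNX node, defaulting to empty.
+    let axes = curr
+        .attrs
+        .iter()
+        .filter_map(|(key, value)| match key.as_str() {
+            "axes" => Some(value.clone().into_i64s()),
+            _ => None,
+        })
+        .next()
+        .unwrap_or_default();
+
+    // Spec compliance: the input must be a tensor.
+    match &curr.inputs.first().unwrap().ty {
+        ArgType::Tensor(_) => (),
+        _ => panic!("Squeeze: only tensor input is valid"),
+    }
+
+    axes
+}
+```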
+
+### Step 4: Create a Config Function
+
+[Create an `<operation_name>_config(curr: &Node)` function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/op_configuration.rs#L975)
+in `src/onnx/op_configuration.rs`.
+
+The `squeeze_conversion()` function in `src/onnx/to_burn.rs` from the previous step calls the
+`squeeze_config()` function in `src/onnx/op_configuration.rs` in order to parse the ONNX node's
+attributes and extract the parameters specific to the Squeeze operation: in this case, the axes
+along which the squeeze is performed.
+
+> 📘 Info: Understanding Generic `config` Patterns
+>
+> The `<operation_name>_config()` functions follow a similar pattern:
+>
+> 1. Extract tensor or scalar types for inputs and outputs.
+> 2. Validate the input structure and types for each node, ensuring they conform to expected
+>    formats (panicking if not).
+> 3. Parse and convert configurations or parameters specific to each operation.
+> 4. Create and return the operation-specific configuration, initialized with the extracted
+>    values.
+>
+> For example, config functions handle specific settings like the kernel size for pooling, or the
+> different tensor and scalar types for power operations.
+
+These functions translate the more varied and flexible structure of ONNX nodes into the more
+structured and type-safe environment of Rust and the Burn framework. Spec compliance is dealt
+with here.
+
+### Step 5: Dimension Inference
+
+If needed,
+[create a dimension inference function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/dim_inference.rs#L271)
+called `<operation_name>_update_output(node: &mut Node)` in `src/onnx/dim_inference.rs`, and match
+the `NodeType` to it in the `dim_inference()` match block. If the dimensions remain unchanged, use
+the `same_as_input()` function instead, for example
+`NodeType::AveragePool1d => same_as_input(node)`.
+
+Dimension inference is an important step in the conversion process where Burn determines the
+dimensions of each output tensor based on the operation.
+[The `dim_inference()` function](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/dim_inference.rs#L14)
+is responsible for determining the dimensions of the output tensors for each node in the graph.
+It does this by:
+
+1. **Matching the Node Type**: The function uses a `match` statement on the `node_type` of each
+   node to apply the correct dimension inference logic depending on the operation.
+2. **Applying Operation-Specific Logic**: For each operation, a specific inference function is
+   called that encapsulates the rules for how output dimensions should be derived from the
+   inputs.
+
+For the Squeeze operation, dimension inference is handled by the `squeeze_update_output()`
+function. Its logic is currently not that nuanced: the output tensor's rank is simply the input
+tensor's rank minus one.
+
+> 📘 Info: How `squeeze_update_output()` Works
+>
+> 1. Validation of axes input: We first check that the second input of the node contains a list
+>    of integers, which represent the axes along which the squeeze operation is applied. The
+>    function also validates that only one axis is specified for squeezing, ensuring that the
+>    operation's requirements within Burn are followed.
+> 2. Extracting input dimensions: The input tensor's rank is extracted from the first input.
+> 3. Configuring output dimensions: The output tensor's rank is then set to one less than the
+>    input tensor's rank, reflecting the reduction in dimensions caused by the squeeze operation.
+> 4. The function includes several checks that throw errors (panics) if the inputs do not meet
+>    the expected types or configurations, such as when the axes are not provided as an integer
+>    list or when the input type is not a tensor.
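+
+As a rough sketch, an update function along these lines could look as follows. This is a
+simplified, hypothetical rendering of the logic described above; the `Data`, `ArgType`, and
+`TensorType` details may differ from the actual `squeeze_update_output()` in `dim_inference.rs`:
+
+```rust
+fn squeeze_update_output(node: &mut Node) {
+    // 1. Validate the axes input: a list of integers with exactly one entry.
+    let axes = match &node.inputs[1].value {
+        Some(Data::Int64s(axes)) => axes.clone(),
+        _ => panic!("Squeeze: axes must be provided as a list of integers"),
+    };
+    if axes.len() != 1 {
+        panic!("Squeeze must specify exactly one axis");
+    }
+
+    // 2. Extract the rank of the input tensor from the first input.
+    let input = match &node.inputs[0].ty {
+        ArgType::Tensor(tensor) => tensor.clone(),
+        _ => panic!("Squeeze: input must be a tensor"),
+    };
+
+    // 3. The output rank is one less than the input rank.
+    node.outputs[0].ty = ArgType::Tensor(TensorType {
+        dim: input.dim - 1,
+        ..input
+    });
+}
+```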
+
+By invoking this function within the `dim_inference()` match block, the output dimensions of each
+node are updated before the graph is finalized. This ensures that all subsequent operations
+within the graph can rely on correct tensor sizes, which is critical both for compiling the graph
+and for runtime execution efficiency.
+
+If something is amiss after this step (i.e. strange panics are happening) and the dimensions of
+your output tensor differ from the dimensions of your input, see the warning at the very end of
+this guide.
+
+### Step 6: Integrate into the Graph Building Process
+
+When a new node type is introduced, it must be added to the
+[`Node` enum](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/base.rs#L77)
+and the
+[`match_all!` macro](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/burn/node/base.rs#L104)
+in `src/burn/node/base.rs`.
+
+The `Node` enum abstracts over the different types of operations (nodes) within a network graph.
+Each variant of the enum corresponds to a specific type of operation and encapsulates the
+operation-specific data structure (like `SqueezeNode`) that was
+[defined in step 2](#step-2-node-implementation).
+
+### Step 7: Add the Newly Supported Op!
+
+As a reward, add an extra checkmark to
+[SUPPORTED-ONNX-OPS.md](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/SUPPORTED-ONNX-OPS.md?plain=1#L1)!
+
+### Misc
+
+> 🚧 **Warning**: Dimension Changes
+>
+> If your operation changes the dimensions of the input tensor, you may need to modify the
+> [`LIFT_CONSTANTS_FOR_NODE_TYPES` list](https://github.com/tracel-ai/burn/blob/9c5b07c833865bff7f82431001076a33d0d8729c/crates/burn-import/src/onnx/from_onnx.rs#L20)
+> in `src/onnx/from_onnx.rs` by adding the `NodeType` of your operation to it.
+
 ## Testing
 
 - Unit tests for the Burn graph to Rust source code conversion are mandatory.