diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 2727e28e7..88bdebfdb 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -160,6 +160,7 @@ macro_rules! match_all { Node::Range(node) => $func(node), Node::Reshape(node) => $func(node), Node::Resize(node) => $func(node), + Node::Scatter(node) => $func(node), Node::Slice(node) => $func(node), Node::Squeeze(node) => $func(node), Node::Sum(node) => $func(node), @@ -216,6 +217,7 @@ impl Node { Node::Range(_) => "range", Node::Reshape(_) => "reshape", Node::Resize(_) => "resize", + Node::Scatter(_) => "scatter", Node::Slice(_) => "slice", Node::Squeeze(_) => "squeeze", Node::Sum(_) => "add", diff --git a/crates/burn-import/src/burn/node/scatter.rs b/crates/burn-import/src/burn/node/scatter.rs index 628b109e0..43a490986 100644 --- a/crates/burn-import/src/burn/node/scatter.rs +++ b/crates/burn-import/src/burn/node/scatter.rs @@ -10,7 +10,7 @@ pub struct ScatterNode { pub indices: TensorType, pub updates: TensorType, pub output: TensorType, - axis: usize, + pub axis: usize, } impl NodeCodegen for ScatterNode { @@ -47,71 +47,79 @@ impl NodeCodegen for ScatterNode { } } -//#[cfg(test)] -//mod tests { -// -// use burn::record::FullPrecisionSettings; -// -// use super::*; -// use crate::burn::{ -// graph::BurnGraph, -// node::{gather::GatherNode, test::assert_tokens}, -// ScalarKind, ScalarType, ShapeType, TensorType, -// }; -// -// #[test] -// fn test_codegen_gather() { -// let mut graph = BurnGraph::::default(); -// -// graph.register(GatherNode::new( -// Type::Tensor(TensorType::new_float("tensor1", 2)), -// Type::Tensor(TensorType::new_int("tensor2", 1)), -// TensorType::new_float("tensor3", 2), -// 0, -// )); -// -// graph.register_input_output( -// vec!["tensor1".to_string(), "tensor2".to_string()], -// vec!["tensor3".to_string()], -// ); -// -// let expected = quote! { -// use burn::tensor::Int; -// use burn::{ -// module::Module, -// tensor::{backend::Backend, Tensor}, -// }; -// -// #[derive(Module, Debug)] -// pub struct Model { -// phantom: core::marker::PhantomData, -// device: burn::module::Ignored, -// } -// -// impl Model { -// #[allow(unused_variables)] -// pub fn new(device: &B::Device) -> Self { -// Self { -// phantom: core::marker::PhantomData, -// device: burn::module::Ignored(device.clone()), -// } -// } -// -// #[allow(clippy::let_and_return, clippy::approx_constant)] -// pub fn forward( -// &self, -// tensor1: Tensor, -// tensor2: Tensor -// ) -> Tensor { -// let tensor3 = tensor1.select(0, tensor2); -// -// tensor3 -// } -// } -// }; -// -// assert_tokens(graph.codegen(), expected); -// } +#[cfg(test)] +mod tests { + + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{gather::GatherNode, test::assert_tokens}, + ScalarKind, ScalarType, ShapeType, TensorType, + }; + + #[test] + fn test_codegen_gather() { + let mut graph = BurnGraph::::default(); + + graph.register(ScatterNode::new( + TensorType::new_float("tensor1", 2), + TensorType::new_int("tensor2", 2), + TensorType::new_float("tensor3", 2), + TensorType::new_float("tensor4", 2), + 0, + )); + + graph.register_input_output( + vec![ + "tensor1".to_string(), + "tensor2".to_string(), + "tensor3".to_string(), + ], + vec!["tensor4".to_string()], + ); + + let expected = quote! { + use burn::tensor::Int; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward( + &self, + tensor1: Tensor, + tensor2: Tensor, + tensor3: Tensor, + ) -> Tensor { + let tensor4 = tensor1.scatter(0, tensor2, tensor3); + tensor4 + } + } + }; + + //println!(" {} ", graph.codegen()); + //assert!(false); + assert_tokens(graph.codegen(), expected); + } +} // // #[test] // fn test_codegen_gather_shape_input() { @@ -220,4 +228,3 @@ impl NodeCodegen for ScatterNode { // assert_tokens(graph.codegen(), expected); // } //} -