working test

This commit is contained in:
mepatrick73 2024-08-23 19:03:19 -04:00
parent 27c7a33dee
commit c77a04666d
2 changed files with 76 additions and 67 deletions

View File

@ -160,6 +160,7 @@ macro_rules! match_all {
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Resize(node) => $func(node),
Node::Scatter(node) => $func(node),
Node::Slice(node) => $func(node),
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
@ -216,6 +217,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Resize(_) => "resize",
Node::Scatter(_) => "scatter",
Node::Slice(_) => "slice",
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",

View File

@ -10,7 +10,7 @@ pub struct ScatterNode {
pub indices: TensorType,
pub updates: TensorType,
pub output: TensorType,
axis: usize,
pub axis: usize,
}
impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
@ -47,71 +47,79 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
}
}
//#[cfg(test)]
//mod tests {
//
// use burn::record::FullPrecisionSettings;
//
// use super::*;
// use crate::burn::{
// graph::BurnGraph,
// node::{gather::GatherNode, test::assert_tokens},
// ScalarKind, ScalarType, ShapeType, TensorType,
// };
//
// #[test]
// fn test_codegen_gather() {
// let mut graph = BurnGraph::<FullPrecisionSettings>::default();
//
// graph.register(GatherNode::new(
// Type::Tensor(TensorType::new_float("tensor1", 2)),
// Type::Tensor(TensorType::new_int("tensor2", 1)),
// TensorType::new_float("tensor3", 2),
// 0,
// ));
//
// graph.register_input_output(
// vec!["tensor1".to_string(), "tensor2".to_string()],
// vec!["tensor3".to_string()],
// );
//
// let expected = quote! {
// use burn::tensor::Int;
// use burn::{
// module::Module,
// tensor::{backend::Backend, Tensor},
// };
//
// #[derive(Module, Debug)]
// pub struct Model<B: Backend> {
// phantom: core::marker::PhantomData<B>,
// device: burn::module::Ignored<B::Device>,
// }
//
// impl<B: Backend> Model <B> {
// #[allow(unused_variables)]
// pub fn new(device: &B::Device) -> Self {
// Self {
// phantom: core::marker::PhantomData,
// device: burn::module::Ignored(device.clone()),
// }
// }
//
// #[allow(clippy::let_and_return, clippy::approx_constant)]
// pub fn forward(
// &self,
// tensor1: Tensor<B, 2>,
// tensor2: Tensor<B, 1, Int>
// ) -> Tensor<B, 2> {
// let tensor3 = tensor1.select(0, tensor2);
//
// tensor3
// }
// }
// };
//
// assert_tokens(graph.codegen(), expected);
// }
#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;
use super::*;
use crate::burn::{
graph::BurnGraph,
node::{gather::GatherNode, test::assert_tokens},
ScalarKind, ScalarType, ShapeType, TensorType,
};
#[test]
fn test_codegen_gather() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
graph.register(ScatterNode::new(
TensorType::new_float("tensor1", 2),
TensorType::new_int("tensor2", 2),
TensorType::new_float("tensor3", 2),
TensorType::new_float("tensor4", 2),
0,
));
graph.register_input_output(
vec![
"tensor1".to_string(),
"tensor2".to_string(),
"tensor3".to_string(),
],
vec!["tensor4".to_string()],
);
let expected = quote! {
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}
impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 2>,
tensor2: Tensor<B, 2, Int>,
tensor3: Tensor<B, 2>,
) -> Tensor<B, 2> {
let tensor4 = tensor1.scatter(0, tensor2, tensor3);
tensor4
}
}
};
//println!(" {} ", graph.codegen());
//assert!(false);
assert_tokens(graph.codegen(), expected);
}
}
//
// #[test]
// fn test_codegen_gather_shape_input() {
@ -220,4 +228,3 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
// assert_tokens(graph.codegen(), expected);
// }
//}