mirror of https://github.com/tracel-ai/burn.git
working test
This commit is contained in:
parent
27c7a33dee
commit
c77a04666d
|
@ -160,6 +160,7 @@ macro_rules! match_all {
|
|||
Node::Range(node) => $func(node),
|
||||
Node::Reshape(node) => $func(node),
|
||||
Node::Resize(node) => $func(node),
|
||||
Node::Scatter(node) => $func(node),
|
||||
Node::Slice(node) => $func(node),
|
||||
Node::Squeeze(node) => $func(node),
|
||||
Node::Sum(node) => $func(node),
|
||||
|
@ -216,6 +217,7 @@ impl<PS: PrecisionSettings> Node<PS> {
|
|||
Node::Range(_) => "range",
|
||||
Node::Reshape(_) => "reshape",
|
||||
Node::Resize(_) => "resize",
|
||||
Node::Scatter(_) => "scatter",
|
||||
Node::Slice(_) => "slice",
|
||||
Node::Squeeze(_) => "squeeze",
|
||||
Node::Sum(_) => "add",
|
||||
|
|
|
@ -10,7 +10,7 @@ pub struct ScatterNode {
|
|||
pub indices: TensorType,
|
||||
pub updates: TensorType,
|
||||
pub output: TensorType,
|
||||
axis: usize,
|
||||
pub axis: usize,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
||||
|
@ -47,71 +47,79 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
|||
}
|
||||
}
|
||||
|
||||
//#[cfg(test)]
|
||||
//mod tests {
|
||||
//
|
||||
// use burn::record::FullPrecisionSettings;
|
||||
//
|
||||
// use super::*;
|
||||
// use crate::burn::{
|
||||
// graph::BurnGraph,
|
||||
// node::{gather::GatherNode, test::assert_tokens},
|
||||
// ScalarKind, ScalarType, ShapeType, TensorType,
|
||||
// };
|
||||
//
|
||||
// #[test]
|
||||
// fn test_codegen_gather() {
|
||||
// let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
//
|
||||
// graph.register(GatherNode::new(
|
||||
// Type::Tensor(TensorType::new_float("tensor1", 2)),
|
||||
// Type::Tensor(TensorType::new_int("tensor2", 1)),
|
||||
// TensorType::new_float("tensor3", 2),
|
||||
// 0,
|
||||
// ));
|
||||
//
|
||||
// graph.register_input_output(
|
||||
// vec!["tensor1".to_string(), "tensor2".to_string()],
|
||||
// vec!["tensor3".to_string()],
|
||||
// );
|
||||
//
|
||||
// let expected = quote! {
|
||||
// use burn::tensor::Int;
|
||||
// use burn::{
|
||||
// module::Module,
|
||||
// tensor::{backend::Backend, Tensor},
|
||||
// };
|
||||
//
|
||||
// #[derive(Module, Debug)]
|
||||
// pub struct Model<B: Backend> {
|
||||
// phantom: core::marker::PhantomData<B>,
|
||||
// device: burn::module::Ignored<B::Device>,
|
||||
// }
|
||||
//
|
||||
// impl<B: Backend> Model <B> {
|
||||
// #[allow(unused_variables)]
|
||||
// pub fn new(device: &B::Device) -> Self {
|
||||
// Self {
|
||||
// phantom: core::marker::PhantomData,
|
||||
// device: burn::module::Ignored(device.clone()),
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// #[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
// pub fn forward(
|
||||
// &self,
|
||||
// tensor1: Tensor<B, 2>,
|
||||
// tensor2: Tensor<B, 1, Int>
|
||||
// ) -> Tensor<B, 2> {
|
||||
// let tensor3 = tensor1.select(0, tensor2);
|
||||
//
|
||||
// tensor3
|
||||
// }
|
||||
// }
|
||||
// };
|
||||
//
|
||||
// assert_tokens(graph.codegen(), expected);
|
||||
// }
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use burn::record::FullPrecisionSettings;
|
||||
|
||||
use super::*;
|
||||
use crate::burn::{
|
||||
graph::BurnGraph,
|
||||
node::{gather::GatherNode, test::assert_tokens},
|
||||
ScalarKind, ScalarType, ShapeType, TensorType,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_codegen_gather() {
|
||||
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
|
||||
|
||||
graph.register(ScatterNode::new(
|
||||
TensorType::new_float("tensor1", 2),
|
||||
TensorType::new_int("tensor2", 2),
|
||||
TensorType::new_float("tensor3", 2),
|
||||
TensorType::new_float("tensor4", 2),
|
||||
0,
|
||||
));
|
||||
|
||||
graph.register_input_output(
|
||||
vec![
|
||||
"tensor1".to_string(),
|
||||
"tensor2".to_string(),
|
||||
"tensor3".to_string(),
|
||||
],
|
||||
vec!["tensor4".to_string()],
|
||||
);
|
||||
|
||||
let expected = quote! {
|
||||
use burn::tensor::Int;
|
||||
use burn::{
|
||||
module::Module,
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct Model<B: Backend> {
|
||||
phantom: core::marker::PhantomData<B>,
|
||||
device: burn::module::Ignored<B::Device>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Model <B> {
|
||||
#[allow(unused_variables)]
|
||||
pub fn new(device: &B::Device) -> Self {
|
||||
Self {
|
||||
phantom: core::marker::PhantomData,
|
||||
device: burn::module::Ignored(device.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::let_and_return, clippy::approx_constant)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
tensor1: Tensor<B, 2>,
|
||||
tensor2: Tensor<B, 2, Int>,
|
||||
tensor3: Tensor<B, 2>,
|
||||
) -> Tensor<B, 2> {
|
||||
let tensor4 = tensor1.scatter(0, tensor2, tensor3);
|
||||
tensor4
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
//println!(" {} ", graph.codegen());
|
||||
//assert!(false);
|
||||
assert_tokens(graph.codegen(), expected);
|
||||
}
|
||||
}
|
||||
//
|
||||
// #[test]
|
||||
// fn test_codegen_gather_shape_input() {
|
||||
|
@ -220,4 +228,3 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
|||
// assert_tokens(graph.codegen(), expected);
|
||||
// }
|
||||
//}
|
||||
|
||||
|
|
Loading…
Reference in New Issue