mirror of https://github.com/tracel-ai/burn.git
wip
This commit is contained in:
parent
bde8e60dc3
commit
27c7a33dee
|
@ -10,7 +10,7 @@ pub struct ScatterNode {
|
|||
pub indices: TensorType,
|
||||
pub updates: TensorType,
|
||||
pub output: TensorType,
|
||||
pub axis: usize,
|
||||
axis: usize,
|
||||
}
|
||||
|
||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
||||
|
@ -22,6 +22,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
|||
vec![
|
||||
Type::Tensor(self.input.clone()),
|
||||
Type::Tensor(self.indices.clone()),
|
||||
Type::Tensor(self.updates.clone()),
|
||||
]
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue