mirror of https://github.com/tracel-ai/burn.git
wip
This commit is contained in:
parent
bde8e60dc3
commit
27c7a33dee
|
@ -10,7 +10,7 @@ pub struct ScatterNode {
|
||||||
pub indices: TensorType,
|
pub indices: TensorType,
|
||||||
pub updates: TensorType,
|
pub updates: TensorType,
|
||||||
pub output: TensorType,
|
pub output: TensorType,
|
||||||
pub axis: usize,
|
axis: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
||||||
|
@ -22,6 +22,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ScatterNode {
|
||||||
vec![
|
vec![
|
||||||
Type::Tensor(self.input.clone()),
|
Type::Tensor(self.input.clone()),
|
||||||
Type::Tensor(self.indices.clone()),
|
Type::Tensor(self.indices.clone()),
|
||||||
|
Type::Tensor(self.updates.clone()),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue