onnx: support negative index in Gather (#2440)

index_select does not support negative indexing, but
this change adds just enough workarounds in onnx to
allow evaluating silero-vad models (which make use of
negative indices).
This commit is contained in:
shua 2024-08-22 15:28:25 +02:00 committed by GitHub
parent a8288b7a72
commit 1e96b8b695
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 0 deletions

View File

@ -629,6 +629,18 @@ fn simple_eval_(
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(xs.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(&indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
// tensor directly, but candle does not support tensor indexing at the moment, so
// some workarounds must be done.