onnx: support negative index in Gather (#2440)
index_select does not support negative indexing, but this change adds just enough workarounds in onnx to allow evaluating silero-vad models (which make use of negative indices).
This commit is contained in:
parent
a8288b7a72
commit
1e96b8b695
|
@ -629,6 +629,18 @@ fn simple_eval_(
|
|||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = xs.normalize_axis(axis)?;
|
||||
|
||||
// index_select does not support negative indices, so normalize them
|
||||
// to positive indices.
|
||||
let indices = &{
|
||||
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
|
||||
let max = Tensor::new(xs.dims()[axis] as i64, indices.device())?
|
||||
.to_dtype(indices.dtype())?;
|
||||
let mask = indices.lt(&zeros)?;
|
||||
mask.to_dtype(indices.dtype())?
|
||||
.broadcast_mul(&max)?
|
||||
.add(&indices)?
|
||||
};
|
||||
|
||||
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
|
||||
// tensor directly, but candle does not support tensor indexing at the moment, so
|
||||
// some workarounds must be done.
|
||||
|
|
Loading…
Reference in New Issue