fix: negative axis (#1296)
* fix: negative axis * Use normalize_axis. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
parent
f772213e84
commit
73d02f4f57
|
@ -298,14 +298,7 @@ pub fn simple_eval(
|
|||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Some(&axis) => {
|
||||
let num_axis = input.rank() as i64;
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
} else if axis < -num_axis {
|
||||
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
|
||||
} else {
|
||||
(num_axis - axis) as usize
|
||||
};
|
||||
let axis = input.normalize_axis(axis)?;
|
||||
candle_nn::ops::log_softmax(input, axis)?
|
||||
}
|
||||
};
|
||||
|
@ -316,14 +309,7 @@ pub fn simple_eval(
|
|||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Some(&axis) => {
|
||||
let num_axis = input.rank() as i64;
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
} else if axis < -num_axis {
|
||||
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
|
||||
} else {
|
||||
(num_axis - axis) as usize
|
||||
};
|
||||
let axis = input.normalize_axis(axis)?;
|
||||
candle_nn::ops::softmax(input, axis)?
|
||||
}
|
||||
};
|
||||
|
@ -666,21 +652,10 @@ pub fn simple_eval(
|
|||
.map(|n| Ok(get(n.as_str())?.clone()))
|
||||
.collect::<Result<Vec<Value>>>()?;
|
||||
let axis: i64 = *get_attr(node, "axis")?;
|
||||
let num_axis = if inputs.is_empty() {
|
||||
if inputs.is_empty() {
|
||||
bail!("empty concat")
|
||||
} else {
|
||||
inputs[0].rank() as i64
|
||||
};
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
} else if axis < -num_axis {
|
||||
bail!(
|
||||
"wrong axis in concat {axis} for shape {:?}",
|
||||
inputs[0].shape()
|
||||
)
|
||||
} else {
|
||||
(num_axis - axis) as usize
|
||||
};
|
||||
let axis = inputs[0].normalize_axis(axis)?;
|
||||
let output = Tensor::cat(&inputs, axis)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue