fix: negative axis (#1296)

* fix: negative axis

* Use normalize_axis.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
YangNianYi 2023-11-09 06:28:21 +08:00 committed by GitHub
parent f772213e84
commit 73d02f4f57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 29 deletions

View File

@ -298,14 +298,7 @@ pub fn simple_eval(
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let num_axis = input.rank() as i64;
let axis = if axis >= 0 {
axis as usize
} else if axis < -num_axis {
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
} else {
(num_axis - axis) as usize
};
let axis = input.normalize_axis(axis)?;
candle_nn::ops::log_softmax(input, axis)?
}
};
@ -316,14 +309,7 @@ pub fn simple_eval(
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let num_axis = input.rank() as i64;
let axis = if axis >= 0 {
axis as usize
} else if axis < -num_axis {
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
} else {
(num_axis - axis) as usize
};
let axis = input.normalize_axis(axis)?;
candle_nn::ops::softmax(input, axis)?
}
};
@ -666,21 +652,10 @@ pub fn simple_eval(
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
let axis: i64 = *get_attr(node, "axis")?;
let num_axis = if inputs.is_empty() {
if inputs.is_empty() {
bail!("empty concat")
} else {
inputs[0].rank() as i64
};
let axis = if axis >= 0 {
axis as usize
} else if axis < -num_axis {
bail!(
"wrong axis in concat {axis} for shape {:?}",
inputs[0].shape()
)
} else {
(num_axis - axis) as usize
};
let axis = inputs[0].normalize_axis(axis)?;
let output = Tensor::cat(&inputs, axis)?;
values.insert(node.output[0].clone(), output);
}