Support more ONNX ops. (#1267)

* Add LogSoftmax.

* Support for Transpose.
This commit is contained in:
Laurent Mazare 2023-11-04 15:10:14 +01:00 committed by GitHub
parent bc9a1bf239
commit dc68c130e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 49 additions and 0 deletions

View File

@ -59,6 +59,26 @@ pub fn simple_eval(
Ok(dt.i)
}
};
let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
None => {
bail!(
"cannot find the '{name}' attribute in '{}' for {}",
node.op_type,
node.name
)
}
Some(dt) => {
match dt.r#type() {
AttributeType::Ints => (),
rtype => bail!(
"unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
node.op_type,
node.name
),
}
Ok(dt.ints.as_slice())
}
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@ -114,6 +134,24 @@ pub fn simple_eval(
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
"LogSoftmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_i("axis") {
Err(_) => candle_nn::ops::softmax_last_dim(input)?,
Ok(axis) => {
let num_axis = input.rank() as i64;
let axis = if axis >= 0 {
axis as usize
} else if axis < -num_axis {
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
} else {
(num_axis - axis) as usize
};
candle_nn::ops::log_softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Softmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_i("axis") {
@ -132,6 +170,17 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
"Transpose" => {
let input = get(&node.input[0])?;
let output = match get_attr_is("perm") {
Err(_) => input.t()?,
Ok(perm) => {
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
input.permute(perm)?
}
};
values.insert(node.output[0].clone(), output);
}
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node