Support more ONNX ops. (#1267)
* Add LogSoftmax. * Support for Transpose.
This commit is contained in:
parent
bc9a1bf239
commit
dc68c130e4
|
@ -59,6 +59,26 @@ pub fn simple_eval(
|
|||
Ok(dt.i)
|
||||
}
|
||||
};
|
||||
let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
bail!(
|
||||
"cannot find the '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => {
|
||||
match dt.r#type() {
|
||||
AttributeType::Ints => (),
|
||||
rtype => bail!(
|
||||
"unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
),
|
||||
}
|
||||
Ok(dt.ints.as_slice())
|
||||
}
|
||||
};
|
||||
// TODO: Validate node.input for each operator.
|
||||
match node.op_type.as_str() {
|
||||
"Add" => {
|
||||
|
@ -114,6 +134,24 @@ pub fn simple_eval(
|
|||
let output = input0.reshape(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"LogSoftmax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_i("axis") {
|
||||
Err(_) => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Ok(axis) => {
|
||||
let num_axis = input.rank() as i64;
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
} else if axis < -num_axis {
|
||||
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
|
||||
} else {
|
||||
(num_axis - axis) as usize
|
||||
};
|
||||
candle_nn::ops::log_softmax(input, axis)?
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Softmax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_i("axis") {
|
||||
|
@ -132,6 +170,17 @@ pub fn simple_eval(
|
|||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Transpose" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_is("perm") {
|
||||
Err(_) => input.t()?,
|
||||
Ok(perm) => {
|
||||
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
|
||||
input.permute(perm)?
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Concat" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
|
||||
let inputs = node
|
||||
|
|
Loading…
Reference in New Issue