ONNX casting support. (#1265)

* ONNX casting support.

* Handle tensor constants.

* Bugfix the binary ops.
Laurent Mazare 2023-11-04 08:34:24 +01:00 committed by GitHub
parent 8cbb9d0e6c
commit f7c957d64f
1 changed file with 94 additions and 10 deletions

@@ -1,5 +1,5 @@
use crate::onnx;
-use candle::{Result, Tensor};
+use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
pub type Value = Tensor;
@@ -13,8 +13,9 @@ pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
+use crate::onnx::attribute_proto::AttributeType;
let graph = match &model.graph {
-None => candle::bail!("no graph defined in proto"),
+None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
// TODO: validate the inputs.
@@ -23,37 +24,37 @@ pub fn simple_eval(
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
-None => candle::bail!("cannot find {input_name} for op {}", node.name),
+None => bail!("cannot find {input_name} for op {}", node.name),
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
let input0 = get(&node.input[0])?;
-let input1 = get(&node.input[0])?;
+let input1 = get(&node.input[1])?;
let output = input0.broadcast_add(input1)?;
values.insert(node.output[0].clone(), output);
}
"Sub" => {
let input0 = get(&node.input[0])?;
-let input1 = get(&node.input[0])?;
+let input1 = get(&node.input[1])?;
let output = input0.broadcast_sub(input1)?;
values.insert(node.output[0].clone(), output);
}
"Mul" => {
let input0 = get(&node.input[0])?;
-let input1 = get(&node.input[0])?;
+let input1 = get(&node.input[1])?;
let output = input0.broadcast_mul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Div" => {
let input0 = get(&node.input[0])?;
-let input1 = get(&node.input[0])?;
+let input1 = get(&node.input[1])?;
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
-let input1 = get(&node.input[0])?;
+let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
@@ -67,14 +68,97 @@ pub fn simple_eval(
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
-op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name),
+// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
+"Constant" => {
+let value = match node.attribute.iter().find(|attr| attr.name == "value") {
+None => {
+// TODO: support sparse_value etc.
+bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
+}
+Some(value) => value,
+};
+let output = match value.r#type() {
+AttributeType::Tensor => {
+use crate::onnx::tensor_proto::DataType;
+let t = value.t.as_ref().unwrap();
+let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
+match DataType::try_from(t.data_type) {
+Ok(DataType::Uint8) => Tensor::from_raw_buffer(
+t.raw_data.as_slice(),
+DType::U8,
+dims.as_slice(),
+&Device::Cpu,
+)?,
+Ok(DataType::Uint32) => Tensor::from_raw_buffer(
+t.raw_data.as_slice(),
+DType::U32,
+dims.as_slice(),
+&Device::Cpu,
+)?,
+Ok(DataType::Int64) => Tensor::from_raw_buffer(
+t.raw_data.as_slice(),
+DType::I64,
+dims.as_slice(),
+&Device::Cpu,
+)?,
+Ok(DataType::Float16) => Tensor::from_raw_buffer(
+t.raw_data.as_slice(),
+DType::F16,
+dims.as_slice(),
+&Device::Cpu,
+)?,
+Ok(DataType::Float) => Tensor::from_raw_buffer(
+t.raw_data.as_slice(),
+DType::F32,
+dims.as_slice(),
+&Device::Cpu,
+)?,
+Ok(DataType::Double) => Tensor::from_raw_buffer(
+t.raw_data.as_slice(),
+DType::F64,
+dims.as_slice(),
+&Device::Cpu,
+)?,
+Ok(dt) => {
+bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
+}
+Err(_) => {
+bail!(
+"unsupported 'value' data-type {} for {}",
+t.data_type,
+node.name
+)
+}
+}
+}
+rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
+};
+values.insert(node.output[0].clone(), output);
+}
+// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
+"Cast" => {
+let input = get(&node.input[0])?;
+let dtype = match node.attribute.iter().find(|attr| attr.name == "to") {
+None => {
+bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name)
+}
+Some(dtype) => match dtype.r#type() {
+AttributeType::Floats => candle::DType::F32,
+AttributeType::Int => candle::DType::I64,
+rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
+},
+};
+let output = input.to_dtype(dtype)?;
+values.insert(node.output[0].clone(), output);
+}
+op_type => bail!("unsupported op_type {op_type} for op {}", node.name),
}
}
graph
.output
.iter()
.map(|output| match values.remove(&output.name) {
-None => candle::bail!("cannot find output {}", output.name),
+None => bail!("cannot find output {}", output.name),
Some(value) => Ok((output.name.clone(), value)),
})
.collect()
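
For reference, below is a minimal sketch, not part of this commit, of how the new Constant and Cast arms could be exercised end to end by building a small ModelProto by hand. It assumes the prost-generated protos are exposed under candle_onnx::onnx, that the generated structs implement Default, that simple_eval is reachable from the candle_onnx crate root, and that candle refers to the candle-core crate renamed in Cargo.toml; treat the module paths and scaffolding as illustrative rather than authoritative.

use std::collections::HashMap;

use candle::{DType, Result};
use candle_onnx::onnx::{
    attribute_proto::AttributeType, tensor_proto::DataType, AttributeProto, GraphProto,
    ModelProto, NodeProto, TensorProto, ValueInfoProto,
};

fn main() -> Result<()> {
    // Constant node: emits a 1-D f32 tensor [1.0, 2.0] stored as little-endian raw_data.
    let constant = NodeProto {
        op_type: "Constant".to_string(),
        name: "const0".to_string(),
        output: vec!["c".to_string()],
        attribute: vec![AttributeProto {
            name: "value".to_string(),
            r#type: AttributeType::Tensor as i32,
            t: Some(TensorProto {
                dims: vec![2],
                data_type: DataType::Float as i32,
                raw_data: [1f32.to_le_bytes(), 2f32.to_le_bytes()].concat(),
                ..Default::default()
            }),
            ..Default::default()
        }],
        ..Default::default()
    };
    // Cast node: y = Cast(c). The "to" attribute is typed as Int, which this
    // version of the evaluator maps to DType::I64.
    let cast = NodeProto {
        op_type: "Cast".to_string(),
        name: "cast0".to_string(),
        input: vec!["c".to_string()],
        output: vec!["y".to_string()],
        attribute: vec![AttributeProto {
            name: "to".to_string(),
            r#type: AttributeType::Int as i32,
            i: 7, // ONNX TensorProto DataType INT64 (informational here, see note below).
            ..Default::default()
        }],
        ..Default::default()
    };
    let model = ModelProto {
        graph: Some(GraphProto {
            node: vec![constant, cast],
            output: vec![ValueInfoProto {
                name: "y".to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }),
        ..Default::default()
    };
    // No external inputs are needed since the graph starts from a Constant.
    let outputs = candle_onnx::simple_eval(&model, HashMap::new())?;
    let y = outputs.get("y").expect("missing output y");
    assert_eq!(y.dtype(), DType::I64);
    assert_eq!(y.to_vec1::<i64>()?, vec![1i64, 2]);
    Ok(())
}

Note that in this version of the evaluator the target dtype of Cast is derived from the attribute's type tag alone (Int maps to i64, Floats to f32), so the i = 7 payload above is not actually interpreted.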