Refactor the onnx attribute getters. (#1268)
* Refactor the onnx attribute getters. * Add get-attr-opt. * Add support for convolutions. * Add support for convolutions.
This commit is contained in:
parent
7051fb8098
commit
b5e4f84bed
|
@ -1,4 +1,5 @@
|
|||
use crate::onnx;
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::collections::HashMap;
|
||||
|
@ -17,6 +18,96 @@ pub fn dtype(dt: DataType) -> Option<DType> {
|
|||
}
|
||||
}
|
||||
|
||||
trait Attr {
|
||||
const TYPE: AttributeType;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
|
||||
}
|
||||
|
||||
impl Attr for i64 {
|
||||
const TYPE: AttributeType = AttributeType::Int;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
Ok(&attr.i)
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for [i64] {
|
||||
const TYPE: AttributeType = AttributeType::Ints;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
Ok(attr.ints.as_slice())
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for str {
|
||||
const TYPE: AttributeType = AttributeType::String;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
bail!(
|
||||
"cannot find the '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => Ok(dt),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
|
||||
let attr = get_attr_(node, name)?;
|
||||
if attr.r#type() != T::TYPE {
|
||||
bail!(
|
||||
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
|
||||
attr.r#type,
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
T::get(attr)
|
||||
}
|
||||
|
||||
fn get_attr_opt<'a, T: Attr + ?Sized>(
|
||||
node: &'a onnx::NodeProto,
|
||||
name: &str,
|
||||
) -> Result<Option<&'a T>> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => Ok(None),
|
||||
Some(attr) => {
|
||||
if attr.r#type() != T::TYPE {
|
||||
bail!(
|
||||
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
|
||||
attr.r#type,
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
let val = T::get(attr)?;
|
||||
Ok(Some(val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => {
|
||||
Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu)
|
||||
}
|
||||
None => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {name}")
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function provides a direct evaluation of the proto.
|
||||
// Longer-term, we should first convert the proto to an intermediate representation of the compute
|
||||
// graph so as to make multiple evaluations more efficient.
|
||||
|
@ -26,59 +117,22 @@ pub fn simple_eval(
|
|||
model: &onnx::ModelProto,
|
||||
inputs: HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
let graph = match &model.graph {
|
||||
None => bail!("no graph defined in proto"),
|
||||
Some(graph) => graph,
|
||||
};
|
||||
// TODO: validate the inputs.
|
||||
let mut values = inputs;
|
||||
for t in graph.initializer.iter() {
|
||||
let tensor = get_tensor(t, t.name.as_str())?;
|
||||
values.insert(t.name.to_string(), tensor);
|
||||
}
|
||||
// The nodes are topologically sorted so we can just process them in order.
|
||||
for node in graph.node.iter() {
|
||||
let get = |input_name: &str| match values.get(input_name) {
|
||||
Some(value) => Ok(value),
|
||||
None => bail!("cannot find {input_name} for op {}", node.name),
|
||||
};
|
||||
let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
bail!(
|
||||
"cannot find the '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => {
|
||||
match dt.r#type() {
|
||||
AttributeType::Int => (),
|
||||
rtype => bail!(
|
||||
"unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
),
|
||||
}
|
||||
Ok(dt.i)
|
||||
}
|
||||
};
|
||||
let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
bail!(
|
||||
"cannot find the '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => {
|
||||
match dt.r#type() {
|
||||
AttributeType::Ints => (),
|
||||
rtype => bail!(
|
||||
"unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
),
|
||||
}
|
||||
Ok(dt.ints.as_slice())
|
||||
}
|
||||
};
|
||||
// TODO: Validate node.input for each operator.
|
||||
match node.op_type.as_str() {
|
||||
"Add" => {
|
||||
|
@ -136,9 +190,9 @@ pub fn simple_eval(
|
|||
}
|
||||
"LogSoftmax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_i("axis") {
|
||||
Err(_) => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Ok(axis) => {
|
||||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Some(&axis) => {
|
||||
let num_axis = input.rank() as i64;
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
|
@ -154,9 +208,9 @@ pub fn simple_eval(
|
|||
}
|
||||
"Softmax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_i("axis") {
|
||||
Err(_) => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Ok(axis) => {
|
||||
let output = match get_attr_opt::<i64>(node, "axis")? {
|
||||
None => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Some(&axis) => {
|
||||
let num_axis = input.rank() as i64;
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
|
@ -172,15 +226,126 @@ pub fn simple_eval(
|
|||
}
|
||||
"Transpose" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_is("perm") {
|
||||
Err(_) => input.t()?,
|
||||
Ok(perm) => {
|
||||
let output = match get_attr_opt::<[i64]>(node, "perm")? {
|
||||
None => input.t()?,
|
||||
Some(perm) => {
|
||||
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
|
||||
input.permute(perm)?
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Conv" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
|
||||
let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
|
||||
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||
match auto_pad {
|
||||
None | Some("NOTSET") => (),
|
||||
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||
};
|
||||
let xs = get(&node.input[0])?;
|
||||
let ws = get(&node.input[1])?;
|
||||
let ys = match ws.rank() {
|
||||
3 => {
|
||||
let pads = match pads {
|
||||
None => 0,
|
||||
Some([p]) => *p as usize,
|
||||
Some([p1, p2]) => {
|
||||
if p1 != p2 {
|
||||
bail!(
|
||||
"left and right pad ({p1} <> {p2}) have to be the same {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(pads) => {
|
||||
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let strides = match strides {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some(s) => {
|
||||
bail!("more strides than expected in conv1d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let dilations = match dilations {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some(s) => {
|
||||
bail!("more dilations than expected in conv1d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
|
||||
}
|
||||
4 => {
|
||||
let pads = match pads {
|
||||
None => 0,
|
||||
Some([p]) => *p as usize,
|
||||
Some([p1, p2, p3, p4]) => {
|
||||
if p1 != p2 || p1 != p3 || p1 != p4 {
|
||||
bail!("pads to be the same {pads:?} {}", node.name)
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(pads) => {
|
||||
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let strides = match strides {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some([p1, p2]) => {
|
||||
if p1 != p2 {
|
||||
bail!(
|
||||
"strides to be the same on both axis {pads:?} {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(s) => {
|
||||
bail!("more strides than expected in conv2d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
let dilations = match dilations {
|
||||
None => 1,
|
||||
Some([p]) => *p as usize,
|
||||
Some([p1, p2]) => {
|
||||
if p1 != p2 {
|
||||
bail!(
|
||||
"dilations to be the same on both axis {pads:?} {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
*p1 as usize
|
||||
}
|
||||
Some(s) => {
|
||||
bail!("more dilations than expected in conv2d {s:?} {}", node.name)
|
||||
}
|
||||
};
|
||||
xs.conv2d(ws, pads, strides, dilations, groups as usize)?
|
||||
}
|
||||
rank => bail!(
|
||||
"unsupported rank for weight matrix {rank} in conv {}",
|
||||
node.name
|
||||
),
|
||||
};
|
||||
let ys = if node.input.len() > 2 {
|
||||
let bs = get(&node.input[2])?;
|
||||
let mut bs_shape = vec![1; ys.rank()];
|
||||
bs_shape[1] = bs.elem_count();
|
||||
ys.broadcast_add(&bs.reshape(bs_shape)?)?
|
||||
} else {
|
||||
ys
|
||||
};
|
||||
values.insert(node.output[0].clone(), ys);
|
||||
}
|
||||
"Concat" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
|
||||
let inputs = node
|
||||
|
@ -188,7 +353,7 @@ pub fn simple_eval(
|
|||
.iter()
|
||||
.map(|n| Ok(get(n.as_str())?.clone()))
|
||||
.collect::<Result<Vec<Value>>>()?;
|
||||
let axis = get_attr_i("axis")?;
|
||||
let axis: i64 = *get_attr(node, "axis")?;
|
||||
let num_axis = if inputs.is_empty() {
|
||||
bail!("empty concat")
|
||||
} else {
|
||||
|
@ -264,27 +429,7 @@ pub fn simple_eval(
|
|||
let output = match value.r#type() {
|
||||
AttributeType::Tensor => {
|
||||
let t = value.t.as_ref().unwrap();
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
dt,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
None => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!(
|
||||
"unsupported 'value' data-type {} for {}",
|
||||
t.data_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
}
|
||||
get_tensor(t, &node.name)?
|
||||
}
|
||||
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
|
||||
};
|
||||
|
@ -293,7 +438,7 @@ pub fn simple_eval(
|
|||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
|
||||
"Cast" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let dt = get_attr_i("to")?;
|
||||
let dt: i64 = *get_attr(node, "to")?;
|
||||
let dtype = match DataType::try_from(dt as i32) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
|
|
Loading…
Reference in New Issue