Refactor the onnx attribute getters. (#1268)

* Refactor the onnx attribute getters.

* Add get-attr-opt.

* Add support for convolutions.

Laurent Mazare 2023-11-04 21:31:48 +01:00 committed by GitHub
parent 7051fb8098
commit b5e4f84bed
1 changed file with 218 additions and 73 deletions


@@ -1,4 +1,5 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
@@ -17,6 +18,96 @@ pub fn dtype(dt: DataType) -> Option<DType> {
}
}
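// Maps a Rust type onto the ONNX attribute type that carries it, and knows
// how to borrow a value of that type out of an AttributeProto.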
trait Attr {
const TYPE: AttributeType;
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
impl Attr for i64 {
const TYPE: AttributeType = AttributeType::Int;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(&attr.i)
}
}
impl Attr for [i64] {
const TYPE: AttributeType = AttributeType::Ints;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
Ok(attr.ints.as_slice())
}
}
impl Attr for str {
const TYPE: AttributeType = AttributeType::String;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
}
}
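// Look an attribute up by name, failing if it is absent; type checking is
// left to the typed wrappers below.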
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => {
bail!(
"cannot find the '{name}' attribute in '{}' for {}",
node.op_type,
node.name
)
}
Some(dt) => Ok(dt),
}
}
fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
let attr = get_attr_(node, name)?;
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
T::get(attr)
}
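// Like get_attr, but a missing attribute yields Ok(None) instead of an
// error; a type mismatch still fails.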
fn get_attr_opt<'a, T: Attr + ?Sized>(
node: &'a onnx::NodeProto,
name: &str,
) -> Result<Option<&'a T>> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => Ok(None),
Some(attr) => {
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
let val = T::get(attr)?;
Ok(Some(val))
}
}
}
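// Build a CPU tensor from a TensorProto's raw buffer; `name` only feeds the
// error messages.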
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => {
Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu)
}
None => {
bail!("unsupported 'value' data-type {dt:?} for {name}")
}
},
Err(_) => {
bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
}
}
}
// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
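For reference, here is a short usage sketch of the new accessors (illustrative, not part of the diff; `node` stands for any onnx::NodeProto). The turbofish picks the attribute type, the Attr bound checks it against the type declared in the proto, and get_attr_opt maps a missing attribute to Ok(None) instead of an error:

fn attr_usage(node: &onnx::NodeProto) -> Result<()> {
    // Required attribute: fails if absent or not declared as an Int.
    let axis: i64 = *get_attr::<i64>(node, "axis")?;
    // Optional attributes: Ok(None) when absent; a type mismatch still fails.
    let perm: Option<&[i64]> = get_attr_opt::<[i64]>(node, "perm")?;
    let auto_pad: Option<&str> = get_attr_opt::<str>(node, "auto_pad")?;
    let _ = (axis, perm, auto_pad);
    Ok(())
}

Supporting another attribute kind only takes one more impl; for instance floats (a hypothetical extension, not in this commit, assuming the proto's single-float field is named `f` as in onnx.proto):

impl Attr for f32 {
    const TYPE: AttributeType = AttributeType::Float;
    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        // `f` is the single-float payload of AttributeProto (hypothetical impl).
        Ok(&attr.f)
    }
}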
@@ -26,59 +117,22 @@ pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
-    use crate::onnx::attribute_proto::AttributeType;
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
// TODO: validate the inputs.
let mut values = inputs;
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
-        let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
-            None => {
-                bail!(
-                    "cannot find the '{name}' attribute in '{}' for {}",
-                    node.op_type,
-                    node.name
-                )
-            }
-            Some(dt) => {
-                match dt.r#type() {
-                    AttributeType::Int => (),
-                    rtype => bail!(
-                        "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
-                        node.op_type,
-                        node.name
-                    ),
-                }
-                Ok(dt.i)
-            }
-        };
-        let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
-            None => {
-                bail!(
-                    "cannot find the '{name}' attribute in '{}' for {}",
-                    node.op_type,
-                    node.name
-                )
-            }
-            Some(dt) => {
-                match dt.r#type() {
-                    AttributeType::Ints => (),
-                    rtype => bail!(
-                        "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
-                        node.op_type,
-                        node.name
-                    ),
-                }
-                Ok(dt.ints.as_slice())
-            }
-        };
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@@ -136,9 +190,9 @@ pub fn simple_eval(
}
"LogSoftmax" => {
let input = get(&node.input[0])?;
-                let output = match get_attr_i("axis") {
-                    Err(_) => candle_nn::ops::softmax_last_dim(input)?,
-                    Ok(axis) => {
+                let output = match get_attr_opt::<i64>(node, "axis")? {
+                    None => candle_nn::ops::softmax_last_dim(input)?,
+                    Some(&axis) => {
let num_axis = input.rank() as i64;
let axis = if axis >= 0 {
axis as usize
@@ -154,9 +208,9 @@ pub fn simple_eval(
}
"Softmax" => {
let input = get(&node.input[0])?;
-                let output = match get_attr_i("axis") {
-                    Err(_) => candle_nn::ops::softmax_last_dim(input)?,
-                    Ok(axis) => {
+                let output = match get_attr_opt::<i64>(node, "axis")? {
+                    None => candle_nn::ops::softmax_last_dim(input)?,
+                    Some(&axis) => {
let num_axis = input.rank() as i64;
let axis = if axis >= 0 {
axis as usize
@@ -172,15 +226,126 @@ pub fn simple_eval(
}
"Transpose" => {
let input = get(&node.input[0])?;
-                let output = match get_attr_is("perm") {
-                    Err(_) => input.t()?,
-                    Ok(perm) => {
+                let output = match get_attr_opt::<[i64]>(node, "perm")? {
+                    None => input.t()?,
+                    Some(perm) => {
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
input.permute(perm)?
}
};
values.insert(node.output[0].clone(), output);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
let xs = get(&node.input[0])?;
let ws = get(&node.input[1])?;
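                // Dispatch on the weight rank: 3-D weights mean conv1d, 4-D mean
                // conv2d. candle takes a single symmetric pad/stride/dilation per
                // call, hence the equality checks on the per-side values below.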
let ys = match ws.rank() {
3 => {
let pads = match pads {
None => 0,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"left and right pad ({p1} <> {p2}) have to be the same {}",
node.name
)
}
*p1 as usize
}
Some(pads) => {
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more strides than expected in conv1d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more dilations than expected in conv1d {s:?} {}", node.name)
}
};
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
}
4 => {
let pads = match pads {
None => 0,
Some([p]) => *p as usize,
Some([p1, p2, p3, p4]) => {
if p1 != p2 || p1 != p3 || p1 != p4 {
bail!("pads to be the same {pads:?} {}", node.name)
}
*p1 as usize
}
Some(pads) => {
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"strides to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more strides than expected in conv2d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"dilations to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more dilations than expected in conv2d {s:?} {}", node.name)
}
};
xs.conv2d(ws, pads, strides, dilations, groups as usize)?
}
rank => bail!(
"unsupported rank for weight matrix {rank} in conv {}",
node.name
),
};
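                // The optional third input is a per-output-channel bias: reshape it
                // to (1, C_out, 1, ...) so broadcast_add applies it across the batch
                // and spatial dimensions.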
let ys = if node.input.len() > 2 {
let bs = get(&node.input[2])?;
let mut bs_shape = vec![1; ys.rank()];
bs_shape[1] = bs.elem_count();
ys.broadcast_add(&bs.reshape(bs_shape)?)?
} else {
ys
};
values.insert(node.output[0].clone(), ys);
}
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node
@@ -188,7 +353,7 @@ pub fn simple_eval(
.iter()
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
-                let axis = get_attr_i("axis")?;
+                let axis: i64 = *get_attr(node, "axis")?;
let num_axis = if inputs.is_empty() {
bail!("empty concat")
} else {
@@ -264,27 +429,7 @@ pub fn simple_eval(
let output = match value.r#type() {
AttributeType::Tensor => {
let t = value.t.as_ref().unwrap();
-                        let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
-                        match DataType::try_from(t.data_type) {
-                            Ok(dt) => match dtype(dt) {
-                                Some(dt) => Tensor::from_raw_buffer(
-                                    t.raw_data.as_slice(),
-                                    dt,
-                                    dims.as_slice(),
-                                    &Device::Cpu,
-                                )?,
-                                None => {
-                                    bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
-                                }
-                            },
-                            Err(_) => {
-                                bail!(
-                                    "unsupported 'value' data-type {} for {}",
-                                    t.data_type,
-                                    node.name
-                                )
-                            }
-                        }
+                        get_tensor(t, &node.name)?
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
@@ -293,7 +438,7 @@ pub fn simple_eval(
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => {
let input = get(&node.input[0])?;
-                let dt = get_attr_i("to")?;
+                let dt: i64 = *get_attr(node, "to")?;
let dtype = match DataType::try_from(dt as i32) {
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
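
As a standalone illustration of the bias handling in the Conv branch above, here is a minimal runnable sketch (assuming the candle crate on CPU; add_conv_bias and the shapes are illustrative, not part of the diff):

use candle::{DType, Device, Result, Tensor};

// Reshape a 1-D per-output-channel bias to (1, C_out, 1, ...) so that
// broadcast_add lines it up with the channel dimension of a conv output.
fn add_conv_bias(ys: &Tensor, bs: &Tensor) -> Result<Tensor> {
    let mut bs_shape = vec![1; ys.rank()];
    bs_shape[1] = bs.elem_count();
    ys.broadcast_add(&bs.reshape(bs_shape)?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for a conv2d output of shape (batch, channels, height, width).
    let ys = Tensor::zeros((1, 2, 4, 4), DType::F32, &dev)?;
    // One bias value per output channel.
    let bs = Tensor::new(&[1f32, 2.], &dev)?;
    let ys = add_conv_bias(&ys, &bs)?;
    assert_eq!(ys.dims(), &[1, 2, 4, 4]);
    Ok(())
}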