Better tensor initialization in ONNX. (#1270)

* Better tensor initialization in ONNX.

* Add MaxPool support.

* Add AvgPool.

* Get the squeezenet example to work.
Laurent Mazare 2023-11-04 22:17:45 +01:00 committed by GitHub
parent b5e4f84bed
commit 39ad840a90
1 changed file with 103 additions and 10 deletions

@@ -96,7 +96,20 @@ fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
     match DataType::try_from(t.data_type) {
         Ok(dt) => match dtype(dt) {
             Some(dt) => {
-                Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu)
+                if dt == DType::F32 && !t.float_data.is_empty() {
+                    Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
+                } else if dt == DType::F64 && !t.double_data.is_empty() {
+                    Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
+                } else if dt == DType::I64 && !t.int64_data.is_empty() {
+                    Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
+                } else {
+                    Tensor::from_raw_buffer(
+                        t.raw_data.as_slice(),
+                        dt,
+                        dims.as_slice(),
+                        &Device::Cpu,
+                    )
+                }
             }
             None => {
                 bail!("unsupported 'value' data-type {dt:?} for {name}")
@@ -174,17 +187,22 @@ pub fn simple_eval(
             "Reshape" => {
                 let input0 = get(&node.input[0])?;
                 let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
-                // TODO: Check that there is at most a single -1, handle other neg values.
+                // TODO: Check that there is at most a single -1 or 0, handle other neg values.
+                let mut other_than_minus1 = 1usize;
+                for &v in input1.iter() {
+                    if v != -1 && v != 0 {
+                        other_than_minus1 *= v as usize
+                    }
+                }
                 let input1 = input1
                     .iter()
-                    .map(|&v| {
-                        if v == -1 {
-                            input0.elem_count()
-                        } else {
-                            v as usize
-                        }
-                    })
-                    .collect::<Vec<usize>>();
+                    .enumerate()
+                    .map(|(idx, &v)| match v {
+                        -1 => Ok(input0.elem_count() / other_than_minus1),
+                        0 => input0.dim(idx),
+                        _ => Ok(v as usize),
+                    })
+                    .collect::<Result<Vec<usize>>>()?;
                 let output = input0.reshape(input1)?;
                 values.insert(node.output[0].clone(), output);
             }
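
For context, ONNX Reshape treats a 0 in the target shape as "copy the matching input dimension" and a single -1 as "infer from the remaining element count", so reshaping a [2, 3, 4] input with [0, -1] should give [2, 12]. Note that other_than_minus1 above skips the 0 entries, so a spec mixing 0 and -1 would still over-infer (24 instead of 12 here), which the remaining TODO hints at. A standalone sketch of the full resolution follows; resolve_reshape is a hypothetical helper for illustration, not part of the crate.

    // Sketch of ONNX Reshape shape resolution: 0 copies the input dim,
    // -1 is inferred so that the element count is preserved. Assumes a
    // valid spec: at most one -1, no other negatives, no zero-sized dims.
    fn resolve_reshape(input_dims: &[usize], spec: &[i64]) -> Vec<usize> {
        let elem_count: usize = input_dims.iter().product();
        // Product of all dims that are already known, including the ones
        // copied from the input via 0.
        let known: usize = spec
            .iter()
            .enumerate()
            .filter(|&(_, &v)| v != -1)
            .map(|(idx, &v)| if v == 0 { input_dims[idx] } else { v as usize })
            .product();
        spec.iter()
            .enumerate()
            .map(|(idx, &v)| match v {
                -1 => elem_count / known,
                0 => input_dims[idx],
                _ => v as usize,
            })
            .collect()
    }

    fn main() {
        assert_eq!(resolve_reshape(&[2, 3, 4], &[0, -1]), vec![2, 12]);
        assert_eq!(resolve_reshape(&[2, 3, 4], &[4, -1]), vec![4, 6]);
    }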
@@ -235,6 +253,81 @@ pub fn simple_eval(
                 };
                 values.insert(node.output[0].clone(), output);
             }
+            "Dropout" => {
+                let input = get(&node.input[0])?;
+                // Do not apply dropout at the moment, consider that we're only doing inference.
+                values.insert(node.output[0].clone(), input.clone());
+            }
+            "MaxPool" => {
+                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
+                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
+                let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
+                let pads = get_attr_opt::<[i64]>(node, "pads")?;
+                let strides = get_attr_opt::<[i64]>(node, "strides")?;
+                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
+                match auto_pad {
+                    None | Some("NOTSET") => (),
+                    Some(s) => bail!("unsupported auto_pad {s}"),
+                };
+                if let Some(d) = dilations {
+                    if d.iter().any(|&v| v != 1) {
+                        bail!("MaxPool with dilation != 1, {dilations:?}")
+                    }
+                }
+                if let Some(d) = pads {
+                    if d.iter().any(|&v| v != 0) {
+                        bail!("MaxPool with pads != 0, {pads:?}")
+                    }
+                }
+                let xs = get(&node.input[0])?;
+                let (k1, k2) = match kernel_shape {
+                    [k1, k2] => (*k1 as usize, *k2 as usize),
+                    _ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
+                };
+                let ys = match strides {
+                    None => xs.max_pool2d((k1, k2))?,
+                    Some([s1, s2]) => {
+                        xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
+                    }
+                    Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
+                };
+                values.insert(node.output[0].clone(), ys);
+            }
+            "AveragePool" => {
+                // https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
+                let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
+                let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
+                let pads = get_attr_opt::<[i64]>(node, "pads")?;
+                let strides = get_attr_opt::<[i64]>(node, "strides")?;
+                let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
+                match auto_pad {
+                    None | Some("NOTSET") => (),
+                    Some(s) => bail!("unsupported auto_pad {s}"),
+                };
+                if let Some(d) = dilations {
+                    if d.iter().any(|&v| v != 1) {
+                        bail!("AvgPool with dilation != 1, {dilations:?}")
+                    }
+                }
+                if let Some(d) = pads {
+                    if d.iter().any(|&v| v != 0) {
+                        bail!("AvgPool with pads != 0, {pads:?}")
+                    }
+                }
+                let xs = get(&node.input[0])?;
+                let (k1, k2) = match kernel_shape {
+                    [k1, k2] => (*k1 as usize, *k2 as usize),
+                    _ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
+                };
+                let ys = match strides {
+                    None => xs.avg_pool2d((k1, k2))?,
+                    Some([s1, s2]) => {
+                        xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
+                    }
+                    Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
+                };
+                values.insert(node.output[0].clone(), ys);
+            }
             "Conv" => {
                 // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
                 let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
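
Both pool handlers bottom out in candle's 2d pooling primitives on NCHW tensors, with the stride falling back to the kernel size when the strides attribute is absent (the ONNX spec itself defaults strides to 1, so models relying on that default will pool differently here). A quick usage sketch of those calls, assuming candle_core as a dependency:

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        // A 1x1x4x4 NCHW input holding 0..16.
        let xs = Tensor::arange(0f32, 16., &Device::Cpu)?.reshape((1, 1, 4, 4))?;
        // 2x2 kernel; max_pool2d uses the kernel size as the stride,
        // matching the handlers' behavior when strides is missing.
        let max = xs.max_pool2d((2, 2))?; // shape (1, 1, 2, 2)
        // An explicit stride of 1 gives overlapping windows.
        let overlapping = xs.max_pool2d_with_stride((2, 2), (1, 1))?; // (1, 1, 3, 3)
        let avg = xs.avg_pool2d((2, 2))?; // (1, 1, 2, 2)
        println!("{max}\n{overlapping}\n{avg}");
        Ok(())
    }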
@@ -453,7 +546,7 @@ pub fn simple_eval(
                 let output = input.to_dtype(dtype)?;
                 values.insert(node.output[0].clone(), output);
             }
-            op_type => bail!("unsupported op_type {op_type} for op {}", node.name),
+            op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
         }
     }
     graph