Better tensor initialization in ONNX. (#1270)
* Better tensor initialization in ONNX. * MaxPool support. * Add AvgPool. * Get the squeezenet example to work.
This commit is contained in:
parent
b5e4f84bed
commit
39ad840a90
|
@ -96,7 +96,20 @@ fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
|||
match DataType::try_from(t.data_type) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => {
|
||||
Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu)
|
||||
if dt == DType::F32 && !t.float_data.is_empty() {
|
||||
Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
|
||||
} else if dt == DType::F64 && !t.double_data.is_empty() {
|
||||
Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
|
||||
} else if dt == DType::I64 && !t.int64_data.is_empty() {
|
||||
Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
|
||||
} else {
|
||||
Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
dt,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)
|
||||
}
|
||||
}
|
||||
None => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {name}")
|
||||
|
@ -174,17 +187,22 @@ pub fn simple_eval(
|
|||
"Reshape" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
|
||||
// TODO: Check that there is at most a single -1, handle other neg values.
|
||||
// TODO: Check that there is at most a single -1 or 0, handle other neg values.
|
||||
let mut other_than_minus1 = 1usize;
|
||||
for &v in input1.iter() {
|
||||
if v != -1 && v != 0 {
|
||||
other_than_minus1 *= v as usize
|
||||
}
|
||||
}
|
||||
let input1 = input1
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
if v == -1 {
|
||||
input0.elem_count()
|
||||
} else {
|
||||
v as usize
|
||||
}
|
||||
.enumerate()
|
||||
.map(|(idx, &v)| match v {
|
||||
-1 => Ok(input0.elem_count() / other_than_minus1),
|
||||
0 => input0.dim(idx),
|
||||
_ => Ok(v as usize),
|
||||
})
|
||||
.collect::<Vec<usize>>();
|
||||
.collect::<Result<Vec<usize>>>()?;
|
||||
let output = input0.reshape(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
|
@ -235,6 +253,81 @@ pub fn simple_eval(
|
|||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Dropout" => {
|
||||
let input = get(&node.input[0])?;
|
||||
// Do not apply dropout at the moment, consider that we're only doing inference.
|
||||
values.insert(node.output[0].clone(), input.clone());
|
||||
}
|
||||
"MaxPool" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
|
||||
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||
match auto_pad {
|
||||
None | Some("NOTSET") => (),
|
||||
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||
};
|
||||
if let Some(d) = dilations {
|
||||
if d.iter().any(|&v| v != 1) {
|
||||
bail!("MaxPool with dilation != 1, {dilations:?}")
|
||||
}
|
||||
}
|
||||
if let Some(d) = pads {
|
||||
if d.iter().any(|&v| v != 0) {
|
||||
bail!("MaxPool with pads != 0, {pads:?}")
|
||||
}
|
||||
}
|
||||
let xs = get(&node.input[0])?;
|
||||
let (k1, k2) = match kernel_shape {
|
||||
[k1, k2] => (*k1 as usize, *k2 as usize),
|
||||
_ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
|
||||
};
|
||||
let ys = match strides {
|
||||
None => xs.max_pool2d((k1, k2))?,
|
||||
Some([s1, s2]) => {
|
||||
xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
|
||||
}
|
||||
Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
|
||||
};
|
||||
values.insert(node.output[0].clone(), ys);
|
||||
}
|
||||
"AveragePool" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
|
||||
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||
match auto_pad {
|
||||
None | Some("NOTSET") => (),
|
||||
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||
};
|
||||
if let Some(d) = dilations {
|
||||
if d.iter().any(|&v| v != 1) {
|
||||
bail!("AvgPool with dilation != 1, {dilations:?}")
|
||||
}
|
||||
}
|
||||
if let Some(d) = pads {
|
||||
if d.iter().any(|&v| v != 0) {
|
||||
bail!("AvgPool with pads != 0, {pads:?}")
|
||||
}
|
||||
}
|
||||
let xs = get(&node.input[0])?;
|
||||
let (k1, k2) = match kernel_shape {
|
||||
[k1, k2] => (*k1 as usize, *k2 as usize),
|
||||
_ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
|
||||
};
|
||||
let ys = match strides {
|
||||
None => xs.avg_pool2d((k1, k2))?,
|
||||
Some([s1, s2]) => {
|
||||
xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
|
||||
}
|
||||
Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
|
||||
};
|
||||
values.insert(node.output[0].clone(), ys);
|
||||
}
|
||||
"Conv" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
|
@ -453,7 +546,7 @@ pub fn simple_eval(
|
|||
let output = input.to_dtype(dtype)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {}", node.name),
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
graph
|
||||
|
|
Loading…
Reference in New Issue