Better tensor initialization in ONNX. (#1270)
* Better tensor initialization in ONNX. * MaxPool support. * Add AvgPool. * Get the squeezenet example to work.
This commit is contained in:
parent
b5e4f84bed
commit
39ad840a90
|
@ -96,7 +96,20 @@ fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||||
match DataType::try_from(t.data_type) {
|
match DataType::try_from(t.data_type) {
|
||||||
Ok(dt) => match dtype(dt) {
|
Ok(dt) => match dtype(dt) {
|
||||||
Some(dt) => {
|
Some(dt) => {
|
||||||
Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu)
|
if dt == DType::F32 && !t.float_data.is_empty() {
|
||||||
|
Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
|
||||||
|
} else if dt == DType::F64 && !t.double_data.is_empty() {
|
||||||
|
Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
|
||||||
|
} else if dt == DType::I64 && !t.int64_data.is_empty() {
|
||||||
|
Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
|
||||||
|
} else {
|
||||||
|
Tensor::from_raw_buffer(
|
||||||
|
t.raw_data.as_slice(),
|
||||||
|
dt,
|
||||||
|
dims.as_slice(),
|
||||||
|
&Device::Cpu,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
bail!("unsupported 'value' data-type {dt:?} for {name}")
|
bail!("unsupported 'value' data-type {dt:?} for {name}")
|
||||||
|
@ -174,17 +187,22 @@ pub fn simple_eval(
|
||||||
"Reshape" => {
|
"Reshape" => {
|
||||||
let input0 = get(&node.input[0])?;
|
let input0 = get(&node.input[0])?;
|
||||||
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
|
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
|
||||||
// TODO: Check that there is at most a single -1, handle other neg values.
|
// TODO: Check that there is at most a single -1 or 0, handle other neg values.
|
||||||
|
let mut other_than_minus1 = 1usize;
|
||||||
|
for &v in input1.iter() {
|
||||||
|
if v != -1 && v != 0 {
|
||||||
|
other_than_minus1 *= v as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
let input1 = input1
|
let input1 = input1
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&v| {
|
.enumerate()
|
||||||
if v == -1 {
|
.map(|(idx, &v)| match v {
|
||||||
input0.elem_count()
|
-1 => Ok(input0.elem_count() / other_than_minus1),
|
||||||
} else {
|
0 => input0.dim(idx),
|
||||||
v as usize
|
_ => Ok(v as usize),
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.collect::<Vec<usize>>();
|
.collect::<Result<Vec<usize>>>()?;
|
||||||
let output = input0.reshape(input1)?;
|
let output = input0.reshape(input1)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
@ -235,6 +253,81 @@ pub fn simple_eval(
|
||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"Dropout" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
// Do not apply dropout at the moment, consider that we're only doing inference.
|
||||||
|
values.insert(node.output[0].clone(), input.clone());
|
||||||
|
}
|
||||||
|
"MaxPool" => {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
|
||||||
|
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||||
|
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
|
||||||
|
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||||
|
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||||
|
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||||
|
match auto_pad {
|
||||||
|
None | Some("NOTSET") => (),
|
||||||
|
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||||
|
};
|
||||||
|
if let Some(d) = dilations {
|
||||||
|
if d.iter().any(|&v| v != 1) {
|
||||||
|
bail!("MaxPool with dilation != 1, {dilations:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(d) = pads {
|
||||||
|
if d.iter().any(|&v| v != 0) {
|
||||||
|
bail!("MaxPool with pads != 0, {pads:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let xs = get(&node.input[0])?;
|
||||||
|
let (k1, k2) = match kernel_shape {
|
||||||
|
[k1, k2] => (*k1 as usize, *k2 as usize),
|
||||||
|
_ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
|
||||||
|
};
|
||||||
|
let ys = match strides {
|
||||||
|
None => xs.max_pool2d((k1, k2))?,
|
||||||
|
Some([s1, s2]) => {
|
||||||
|
xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
|
||||||
|
}
|
||||||
|
Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
|
||||||
|
};
|
||||||
|
values.insert(node.output[0].clone(), ys);
|
||||||
|
}
|
||||||
|
"AveragePool" => {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
|
||||||
|
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||||
|
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
|
||||||
|
let pads = get_attr_opt::<[i64]>(node, "pads")?;
|
||||||
|
let strides = get_attr_opt::<[i64]>(node, "strides")?;
|
||||||
|
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
|
||||||
|
match auto_pad {
|
||||||
|
None | Some("NOTSET") => (),
|
||||||
|
Some(s) => bail!("unsupported auto_pad {s}"),
|
||||||
|
};
|
||||||
|
if let Some(d) = dilations {
|
||||||
|
if d.iter().any(|&v| v != 1) {
|
||||||
|
bail!("AvgPool with dilation != 1, {dilations:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(d) = pads {
|
||||||
|
if d.iter().any(|&v| v != 0) {
|
||||||
|
bail!("AvgPool with pads != 0, {pads:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let xs = get(&node.input[0])?;
|
||||||
|
let (k1, k2) = match kernel_shape {
|
||||||
|
[k1, k2] => (*k1 as usize, *k2 as usize),
|
||||||
|
_ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
|
||||||
|
};
|
||||||
|
let ys = match strides {
|
||||||
|
None => xs.avg_pool2d((k1, k2))?,
|
||||||
|
Some([s1, s2]) => {
|
||||||
|
xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
|
||||||
|
}
|
||||||
|
Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
|
||||||
|
};
|
||||||
|
values.insert(node.output[0].clone(), ys);
|
||||||
|
}
|
||||||
"Conv" => {
|
"Conv" => {
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||||
|
@ -453,7 +546,7 @@ pub fn simple_eval(
|
||||||
let output = input.to_dtype(dtype)?;
|
let output = input.to_dtype(dtype)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {}", node.name),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
graph
|
graph
|
||||||
|
|
Loading…
Reference in New Issue