diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index bc04eb00..4d44bd8e 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -96,7 +96,20 @@ fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result { match DataType::try_from(t.data_type) { Ok(dt) => match dtype(dt) { Some(dt) => { - Tensor::from_raw_buffer(t.raw_data.as_slice(), dt, dims.as_slice(), &Device::Cpu) + if dt == DType::F32 && !t.float_data.is_empty() { + Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu) + } else if dt == DType::F64 && !t.double_data.is_empty() { + Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu) + } else if dt == DType::I64 && !t.int64_data.is_empty() { + Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu) + } else { + Tensor::from_raw_buffer( + t.raw_data.as_slice(), + dt, + dims.as_slice(), + &Device::Cpu, + ) + } } None => { bail!("unsupported 'value' data-type {dt:?} for {name}") @@ -174,17 +187,22 @@ pub fn simple_eval( "Reshape" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?.to_vec1::()?; - // TODO: Check that there is at most a single -1, handle other neg values. + // TODO: Check that there is at most a single -1 or 0, handle other neg values. + let mut other_than_minus1 = 1usize; + for &v in input1.iter() { + if v != -1 && v != 0 { + other_than_minus1 *= v as usize + } + } let input1 = input1 .iter() - .map(|&v| { - if v == -1 { - input0.elem_count() - } else { - v as usize - } + .enumerate() + .map(|(idx, &v)| match v { + -1 => Ok(input0.elem_count() / other_than_minus1), + 0 => input0.dim(idx), + _ => Ok(v as usize), }) - .collect::>(); + .collect::>>()?; let output = input0.reshape(input1)?; values.insert(node.output[0].clone(), output); } @@ -235,6 +253,81 @@ pub fn simple_eval( }; values.insert(node.output[0].clone(), output); } + "Dropout" => { + let input = get(&node.input[0])?; + // Do not apply dropout at the moment, consider that we're only doing inference. + values.insert(node.output[0].clone(), input.clone()); + } + "MaxPool" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool + let dilations = get_attr_opt::<[i64]>(node, "dilations")?; + let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?; + let pads = get_attr_opt::<[i64]>(node, "pads")?; + let strides = get_attr_opt::<[i64]>(node, "strides")?; + let auto_pad = get_attr_opt::(node, "auto_pad")?; + match auto_pad { + None | Some("NOTSET") => (), + Some(s) => bail!("unsupported auto_pad {s}"), + }; + if let Some(d) = dilations { + if d.iter().any(|&v| v != 1) { + bail!("MaxPool with dilation != 1, {dilations:?}") + } + } + if let Some(d) = pads { + if d.iter().any(|&v| v != 0) { + bail!("MaxPool with pads != 0, {pads:?}") + } + } + let xs = get(&node.input[0])?; + let (k1, k2) = match kernel_shape { + [k1, k2] => (*k1 as usize, *k2 as usize), + _ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"), + }; + let ys = match strides { + None => xs.max_pool2d((k1, k2))?, + Some([s1, s2]) => { + xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))? + } + Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"), + }; + values.insert(node.output[0].clone(), ys); + } + "AveragePool" => { + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool + let dilations = get_attr_opt::<[i64]>(node, "dilations")?; + let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?; + let pads = get_attr_opt::<[i64]>(node, "pads")?; + let strides = get_attr_opt::<[i64]>(node, "strides")?; + let auto_pad = get_attr_opt::(node, "auto_pad")?; + match auto_pad { + None | Some("NOTSET") => (), + Some(s) => bail!("unsupported auto_pad {s}"), + }; + if let Some(d) = dilations { + if d.iter().any(|&v| v != 1) { + bail!("AvgPool with dilation != 1, {dilations:?}") + } + } + if let Some(d) = pads { + if d.iter().any(|&v| v != 0) { + bail!("AvgPool with pads != 0, {pads:?}") + } + } + let xs = get(&node.input[0])?; + let (k1, k2) = match kernel_shape { + [k1, k2] => (*k1 as usize, *k2 as usize), + _ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"), + }; + let ys = match strides { + None => xs.avg_pool2d((k1, k2))?, + Some([s1, s2]) => { + xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))? + } + Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"), + }; + values.insert(node.output[0].clone(), ys); + } "Conv" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv let dilations = get_attr_opt::<[i64]>(node, "dilations")?; @@ -453,7 +546,7 @@ pub fn simple_eval( let output = input.to_dtype(dtype)?; values.insert(node.output[0].clone(), output); } - op_type => bail!("unsupported op_type {op_type} for op {}", node.name), + op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } graph