PyO3: Add optional `candle.onnx` module (#1282)
* Start onnx integration * Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx * Implement ONNXModel * `fmt` * add `onnx` flag to python ci * Pin `protoc` to `25.0` * Setup `protoc` in wheel builds * Build wheels with `onnx` * Install `protoc` in manylinux containers * `apt` -> `yum` * Download `protoc` via bash script * Back to `manylinux: auto` * Disable `onnx` builds for linux
This commit is contained in:
parent
7920b45c8a
commit
f3a4f3db76
Binary file not shown.
|
@ -39,6 +39,12 @@ jobs:
|
|||
path: ~/.cargo/registry
|
||||
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
|
||||
|
||||
- name: Install Protoc
|
||||
uses: arduino/setup-protoc@v2
|
||||
with:
|
||||
version: "25.0"
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Install
|
||||
working-directory: ./candle-pyo3
|
||||
run: |
|
||||
|
@ -46,7 +52,7 @@ jobs:
|
|||
source .env/bin/activate
|
||||
pip install -U pip
|
||||
pip install pytest maturin black
|
||||
python -m maturin develop -r
|
||||
python -m maturin develop -r --features onnx
|
||||
|
||||
- name: Check style
|
||||
working-directory: ./candle-pyo3
|
||||
|
|
|
@ -98,7 +98,7 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
|
|||
}
|
||||
}
|
||||
|
||||
fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
Ok(DataType::Int32) => {
|
||||
|
|
|
@ -5,7 +5,7 @@ pub mod onnx {
|
|||
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
|
||||
}
|
||||
|
||||
mod eval;
|
||||
pub mod eval;
|
||||
pub use eval::{dtype, simple_eval};
|
||||
|
||||
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
|
||||
|
|
|
@ -17,6 +17,7 @@ crate-type = ["cdylib"]
|
|||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.3.0" }
|
||||
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
|
||||
half = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
||||
|
@ -29,3 +30,5 @@ default = []
|
|||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src","candle/mkl"]
|
||||
onnx = ["dep:candle-onnx"]
|
||||
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
# Generated content DO NOT EDIT
|
||||
from .. import onnx
|
||||
|
||||
ONNXModel = onnx.ONNXModel
|
||||
ONNXTensorDescription = onnx.ONNXTensorDescription
|
|
@ -0,0 +1,89 @@
|
|||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
|
||||
from candle import Tensor, DType, QTensor
|
||||
|
||||
class ONNXModel:
|
||||
"""
|
||||
A wrapper around an ONNX model.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str):
|
||||
pass
|
||||
@property
|
||||
def doc_string(self) -> str:
|
||||
"""
|
||||
The doc string of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
"""
|
||||
The domain of the operator set of the model.
|
||||
"""
|
||||
pass
|
||||
def initializers(self) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Get the weights of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
|
||||
"""
|
||||
The inputs of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def ir_version(self) -> int:
|
||||
"""
|
||||
The version of the IR this model targets.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def model_version(self) -> int:
|
||||
"""
|
||||
The version of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
|
||||
"""
|
||||
The outputs of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def producer_name(self) -> str:
|
||||
"""
|
||||
The producer of the model.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def producer_version(self) -> str:
|
||||
"""
|
||||
The version of the producer of the model.
|
||||
"""
|
||||
pass
|
||||
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""
|
||||
Run the model on the given inputs.
|
||||
"""
|
||||
pass
|
||||
|
||||
class ONNXTensorDescription:
|
||||
"""
|
||||
A wrapper around an ONNX tensor description.
|
||||
"""
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
"""
|
||||
The data type of the tensor.
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
def shape(self) -> Tuple[Union[int, str, Any]]:
|
||||
"""
|
||||
The shape of the tensor.
|
||||
"""
|
||||
pass
|
|
@ -19,12 +19,14 @@ extern crate accelerate_src;
|
|||
|
||||
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
|
||||
|
||||
mod utils;
|
||||
use utils::wrap_err;
|
||||
|
||||
mod shape;
|
||||
use shape::{PyShape, PyShapeWithHole};
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
||||
#[cfg(feature = "onnx")]
|
||||
mod onnx;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "Tensor")]
|
||||
|
@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "onnx")]
|
||||
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
|
||||
m.add_class::<PyONNXModel>()?;
|
||||
m.add_class::<PyONNXTensorDescriptor>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let utils = PyModule::new(py, "utils")?;
|
||||
|
@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||
let nn = PyModule::new(py, "functional")?;
|
||||
candle_functional_m(py, nn)?;
|
||||
m.add_submodule(nn)?;
|
||||
#[cfg(feature = "onnx")]
|
||||
{
|
||||
let onnx = PyModule::new(py, "onnx")?;
|
||||
candle_onnx_m(py, onnx)?;
|
||||
m.add_submodule(onnx)?;
|
||||
}
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_class::<PyQTensor>()?;
|
||||
m.add_class::<PyDType>()?;
|
||||
|
|
|
@ -0,0 +1,212 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::utils::wrap_err;
|
||||
use crate::{PyDType, PyTensor};
|
||||
use candle_onnx::eval::{dtype, get_tensor, simple_eval};
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle_onnx::onnx::tensor_shape_proto::dimension::Value;
|
||||
use candle_onnx::onnx::type_proto::{Tensor as ONNXTensor, Value as ONNXValue};
|
||||
use candle_onnx::onnx::{ModelProto, ValueInfoProto};
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyList, PyTuple};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "ONNXTensorDescription")]
|
||||
/// A wrapper around an ONNX tensor description.
|
||||
pub struct PyONNXTensorDescriptor(ONNXTensor);
|
||||
|
||||
#[pymethods]
|
||||
impl PyONNXTensorDescriptor {
|
||||
#[getter]
|
||||
/// The data type of the tensor.
|
||||
/// &RETURNS&: DType
|
||||
fn dtype(&self) -> PyResult<PyDType> {
|
||||
match DataType::try_from(self.0.elem_type) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => Ok(PyDType(dt)),
|
||||
None => Err(PyValueError::new_err(format!(
|
||||
"unsupported 'value' data-type {dt:?}"
|
||||
))),
|
||||
},
|
||||
type_ => Err(PyValueError::new_err(format!(
|
||||
"unsupported input type {type_:?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The shape of the tensor.
|
||||
/// &RETURNS&: Tuple[Union[int,str,Any]]
|
||||
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
|
||||
let shape = PyList::empty(py);
|
||||
if let Some(d) = &self.0.shape {
|
||||
for dim in d.dim.iter() {
|
||||
if let Some(value) = &dim.value {
|
||||
match value {
|
||||
Value::DimValue(v) => shape.append(*v)?,
|
||||
Value::DimParam(s) => shape.append(s.clone())?,
|
||||
};
|
||||
} else {
|
||||
return Err(PyValueError::new_err("None value in shape"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(shape.to_tuple().into())
|
||||
}
|
||||
|
||||
fn __repr__(&self, py: Python) -> String {
|
||||
match (self.shape(py), self.dtype()) {
|
||||
(Ok(shape), Ok(dtype)) => format!(
|
||||
"TensorDescriptor[shape: {:?}, dtype: {:?}]",
|
||||
shape.to_string(),
|
||||
dtype.__str__()
|
||||
),
|
||||
(Err(_), Err(_)) => "TensorDescriptor[shape: unknown, dtype: unknown]".to_string(),
|
||||
(Err(_), Ok(dtype)) => format!(
|
||||
"TensorDescriptor[shape: unknown, dtype: {:?}]",
|
||||
dtype.__str__()
|
||||
),
|
||||
(Ok(shape), Err(_)) => format!(
|
||||
"TensorDescriptor[shape: {:?}, dtype: unknown]",
|
||||
shape.to_string()
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn __str__(&self, py: Python) -> String {
|
||||
self.__repr__(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[pyclass(name = "ONNXModel")]
|
||||
/// A wrapper around an ONNX model.
|
||||
pub struct PyONNXModel(ModelProto);
|
||||
|
||||
fn extract_tensor_descriptions(
|
||||
value_infos: &[ValueInfoProto],
|
||||
) -> HashMap<String, PyONNXTensorDescriptor> {
|
||||
let mut map = HashMap::new();
|
||||
for value_info in value_infos.iter() {
|
||||
let input_type = match &value_info.r#type {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
let input_type = match &input_type.value {
|
||||
Some(input_type) => input_type,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let tensor_type: &ONNXTensor = match input_type {
|
||||
ONNXValue::TensorType(tt) => tt,
|
||||
_ => continue,
|
||||
};
|
||||
map.insert(
|
||||
value_info.name.to_string(),
|
||||
PyONNXTensorDescriptor(tensor_type.clone()),
|
||||
);
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyONNXModel {
|
||||
#[new]
|
||||
#[pyo3(text_signature = "(self, path:str)")]
|
||||
/// Load an ONNX model from the given path.
|
||||
fn new(path: String) -> PyResult<Self> {
|
||||
let model: ModelProto = candle_onnx::read_file(path).map_err(wrap_err)?;
|
||||
Ok(PyONNXModel(model))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The version of the IR this model targets.
|
||||
/// &RETURNS&: int
|
||||
fn ir_version(&self) -> i64 {
|
||||
self.0.ir_version
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The producer of the model.
|
||||
/// &RETURNS&: str
|
||||
fn producer_name(&self) -> String {
|
||||
self.0.producer_name.clone()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The version of the producer of the model.
|
||||
/// &RETURNS&: str
|
||||
fn producer_version(&self) -> String {
|
||||
self.0.producer_version.clone()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The domain of the operator set of the model.
|
||||
/// &RETURNS&: str
|
||||
fn domain(&self) -> String {
|
||||
self.0.domain.clone()
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The version of the model.
|
||||
/// &RETURNS&: int
|
||||
fn model_version(&self) -> i64 {
|
||||
self.0.model_version
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The doc string of the model.
|
||||
/// &RETURNS&: str
|
||||
fn doc_string(&self) -> String {
|
||||
self.0.doc_string.clone()
|
||||
}
|
||||
|
||||
/// Get the weights of the model.
|
||||
/// &RETURNS&: Dict[str, Tensor]
|
||||
fn initializers(&self) -> PyResult<HashMap<String, PyTensor>> {
|
||||
let mut map = HashMap::new();
|
||||
if let Some(graph) = self.0.graph.as_ref() {
|
||||
for tensor_description in graph.initializer.iter() {
|
||||
let tensor = get_tensor(tensor_description, tensor_description.name.as_str())
|
||||
.map_err(wrap_err)?;
|
||||
map.insert(tensor_description.name.to_string(), PyTensor(tensor));
|
||||
}
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The inputs of the model.
|
||||
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
|
||||
fn inputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
|
||||
if let Some(graph) = self.0.graph.as_ref() {
|
||||
return Some(extract_tensor_descriptions(&graph.input));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// The outputs of the model.
|
||||
/// &RETURNS&: Optional[Dict[str, ONNXTensorDescription]]
|
||||
fn outputs(&self) -> Option<HashMap<String, PyONNXTensorDescriptor>> {
|
||||
if let Some(graph) = self.0.graph.as_ref() {
|
||||
return Some(extract_tensor_descriptions(&graph.output));
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[pyo3(text_signature = "(self, inputs:Dict[str,Tensor])")]
|
||||
/// Run the model on the given inputs.
|
||||
/// &RETURNS&: Dict[str,Tensor]
|
||||
fn run(&self, inputs: HashMap<String, PyTensor>) -> PyResult<HashMap<String, PyTensor>> {
|
||||
let unwrapped_tensors = inputs.into_iter().map(|(k, v)| (k.clone(), v.0)).collect();
|
||||
|
||||
let result = simple_eval(&self.0, unwrapped_tensors).map_err(wrap_err)?;
|
||||
|
||||
Ok(result
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.clone(), PyTensor(v)))
|
||||
.collect())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
pub fn wrap_err(err: ::candle::Error) -> PyErr {
|
||||
PyErr::new::<PyValueError, _>(format!("{err:?}"))
|
||||
}
|
Loading…
Reference in New Issue