Update for pyo3 0.21. (#1985)
* Update for pyo3 0.21. * Also adapt the RL example. * Fix for the pyo3-onnx bindings... * Print details on failures. * Revert pyi.
This commit is contained in:
parent
5522bbc57c
commit
b20acd622c
|
@ -25,7 +25,7 @@ hf-hub = { workspace = true, features = ["tokio"] }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
|
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
rubato = { version = "0.15.0", optional = true }
|
rubato = { version = "0.15.0", optional = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
|
|
|
@ -42,7 +42,7 @@ impl GymEnv {
|
||||||
/// Creates a new session of the specified OpenAI Gym environment.
|
/// Creates a new session of the specified OpenAI Gym environment.
|
||||||
pub fn new(name: &str) -> Result<GymEnv> {
|
pub fn new(name: &str) -> Result<GymEnv> {
|
||||||
Python::with_gil(|py| {
|
Python::with_gil(|py| {
|
||||||
let gym = py.import("gymnasium")?;
|
let gym = py.import_bound("gymnasium")?;
|
||||||
let make = gym.getattr("make")?;
|
let make = gym.getattr("make")?;
|
||||||
let env = make.call1((name,))?;
|
let env = make.call1((name,))?;
|
||||||
let action_space = env.getattr("action_space")?;
|
let action_space = env.getattr("action_space")?;
|
||||||
|
@ -66,10 +66,10 @@ impl GymEnv {
|
||||||
/// Resets the environment, returning the observation tensor.
|
/// Resets the environment, returning the observation tensor.
|
||||||
pub fn reset(&self, seed: u64) -> Result<Tensor> {
|
pub fn reset(&self, seed: u64) -> Result<Tensor> {
|
||||||
let state: Vec<f32> = Python::with_gil(|py| {
|
let state: Vec<f32> = Python::with_gil(|py| {
|
||||||
let kwargs = PyDict::new(py);
|
let kwargs = PyDict::new_bound(py);
|
||||||
kwargs.set_item("seed", seed)?;
|
kwargs.set_item("seed", seed)?;
|
||||||
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
|
let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
|
||||||
state.as_ref(py).get_item(0)?.extract()
|
state.bind(py).get_item(0)?.extract()
|
||||||
})
|
})
|
||||||
.map_err(w)?;
|
.map_err(w)?;
|
||||||
Tensor::new(state, &Device::Cpu)
|
Tensor::new(state, &Device::Cpu)
|
||||||
|
@ -81,8 +81,10 @@ impl GymEnv {
|
||||||
action: A,
|
action: A,
|
||||||
) -> Result<Step<A>> {
|
) -> Result<Step<A>> {
|
||||||
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
|
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
|
||||||
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
|
let step = self
|
||||||
let step = step.as_ref(py);
|
.env
|
||||||
|
.call_method_bound(py, "step", (action.clone(),), None)?;
|
||||||
|
let step = step.bind(py);
|
||||||
let state: Vec<f32> = step.get_item(0)?.extract()?;
|
let state: Vec<f32> = step.get_item(0)?.extract()?;
|
||||||
let reward: f64 = step.get_item(1)?.extract()?;
|
let reward: f64 = step.get_item(1)?.extract()?;
|
||||||
let terminated: bool = step.get_item(2)?.extract()?;
|
let terminated: bool = step.get_item(2)?.extract()?;
|
||||||
|
|
|
@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
|
||||||
impl VecGymEnv {
|
impl VecGymEnv {
|
||||||
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
|
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
|
||||||
Python::with_gil(|py| {
|
Python::with_gil(|py| {
|
||||||
let sys = py.import("sys")?;
|
let sys = py.import_bound("sys")?;
|
||||||
let path = sys.getattr("path")?;
|
let path = sys.getattr("path")?;
|
||||||
let _ = path.call_method1(
|
let _ = path.call_method1(
|
||||||
"append",
|
"append",
|
||||||
("candle-examples/examples/reinforcement-learning",),
|
("candle-examples/examples/reinforcement-learning",),
|
||||||
)?;
|
)?;
|
||||||
let gym = py.import("atari_wrappers")?;
|
let gym = py.import_bound("atari_wrappers")?;
|
||||||
let make = gym.getattr("make")?;
|
let make = gym.getattr("make")?;
|
||||||
let env = make.call1((name, img_dir, nprocesses))?;
|
let env = make.call1((name, img_dir, nprocesses))?;
|
||||||
let action_space = env.getattr("action_space")?;
|
let action_space = env.getattr("action_space")?;
|
||||||
|
@ -60,10 +60,10 @@ impl VecGymEnv {
|
||||||
|
|
||||||
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
|
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
|
||||||
let (obs, reward, is_done) = Python::with_gil(|py| {
|
let (obs, reward, is_done) = Python::with_gil(|py| {
|
||||||
let step = self.env.call_method(py, "step", (action,), None)?;
|
let step = self.env.call_method_bound(py, "step", (action,), None)?;
|
||||||
let step = step.as_ref(py);
|
let step = step.bind(py);
|
||||||
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
|
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
|
||||||
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
|
let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
|
||||||
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
|
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
|
||||||
let reward: Vec<f32> = step.get_item(1)?.extract()?;
|
let reward: Vec<f32> = step.get_item(1)?.extract()?;
|
||||||
let is_done: Vec<f32> = step.get_item(2)?.extract()?;
|
let is_done: Vec<f32> = step.get_item(2)?.extract()?;
|
||||||
|
|
|
@ -20,10 +20,10 @@ candle-nn = { workspace = true }
|
||||||
candle-onnx = { workspace = true, optional = true }
|
candle-onnx = { workspace = true, optional = true }
|
||||||
half = { workspace = true }
|
half = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
|
pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
pyo3-build-config = "0.20"
|
pyo3-build-config = "0.21"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Generated content DO NOT EDIT
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||||
|
from os import PathLike
|
||||||
|
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
|
||||||
|
from candle import Tensor, DType, QTensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def silu(tensor: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def softmax(tensor: Tensor, dim: int) -> Tensor:
|
||||||
|
"""
|
||||||
|
Applies the Softmax function to a given tensor.#
|
||||||
|
"""
|
||||||
|
pass
|
|
@ -60,8 +60,8 @@ impl PyDType {
|
||||||
impl PyDType {
|
impl PyDType {
|
||||||
fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
|
fn from_pyobject(ob: PyObject, py: Python<'_>) -> PyResult<Self> {
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
if let Ok(dtype) = ob.extract::<&str>(py) {
|
if let Ok(dtype) = ob.extract::<String>(py) {
|
||||||
let dtype = DType::from_str(dtype)
|
let dtype = DType::from_str(&dtype)
|
||||||
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
|
.map_err(|_| PyTypeError::new_err(format!("invalid dtype '{dtype}'")))?;
|
||||||
Ok(Self(dtype))
|
Ok(Self(dtype))
|
||||||
} else {
|
} else {
|
||||||
|
@ -116,8 +116,8 @@ impl PyDevice {
|
||||||
|
|
||||||
impl<'source> FromPyObject<'source> for PyDevice {
|
impl<'source> FromPyObject<'source> for PyDevice {
|
||||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||||
let device: &str = ob.extract()?;
|
let device: String = ob.extract()?;
|
||||||
let device = match device {
|
let device = match device.as_str() {
|
||||||
"cpu" => PyDevice::Cpu,
|
"cpu" => PyDevice::Cpu,
|
||||||
"cuda" => PyDevice::Cuda,
|
"cuda" => PyDevice::Cuda,
|
||||||
_ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
|
_ => Err(PyTypeError::new_err(format!("invalid device '{device}'")))?,
|
||||||
|
@ -265,7 +265,7 @@ impl PyTensor {
|
||||||
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
|
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
|
||||||
return PyTensor::new(py, numpy);
|
return PyTensor::new(py, numpy);
|
||||||
} else {
|
} else {
|
||||||
let ty = data.as_ref(py).get_type();
|
let ty = data.bind(py).get_type();
|
||||||
Err(PyTypeError::new_err(format!(
|
Err(PyTypeError::new_err(format!(
|
||||||
"incorrect type {ty} for tensor"
|
"incorrect type {ty} for tensor"
|
||||||
)))?
|
)))?
|
||||||
|
@ -322,7 +322,7 @@ impl PyTensor {
|
||||||
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
|
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
let candle_values = self.values(py)?;
|
let candle_values = self.values(py)?;
|
||||||
let torch_tensor: PyObject = py
|
let torch_tensor: PyObject = py
|
||||||
.import("torch")?
|
.import_bound("torch")?
|
||||||
.getattr("tensor")?
|
.getattr("tensor")?
|
||||||
.call1((candle_values,))?
|
.call1((candle_values,))?
|
||||||
.extract()?;
|
.extract()?;
|
||||||
|
@ -333,7 +333,7 @@ impl PyTensor {
|
||||||
/// Gets the tensor's shape.
|
/// Gets the tensor's shape.
|
||||||
/// &RETURNS&: Tuple[int]
|
/// &RETURNS&: Tuple[int]
|
||||||
fn shape(&self, py: Python<'_>) -> PyObject {
|
fn shape(&self, py: Python<'_>) -> PyObject {
|
||||||
PyTuple::new(py, self.0.dims()).to_object(py)
|
PyTuple::new_bound(py, self.0.dims()).to_object(py)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
|
@ -347,7 +347,7 @@ impl PyTensor {
|
||||||
/// Gets the tensor's strides.
|
/// Gets the tensor's strides.
|
||||||
/// &RETURNS&: Tuple[int]
|
/// &RETURNS&: Tuple[int]
|
||||||
fn stride(&self, py: Python<'_>) -> PyObject {
|
fn stride(&self, py: Python<'_>) -> PyObject {
|
||||||
PyTuple::new(py, self.0.stride()).to_object(py)
|
PyTuple::new_bound(py, self.0.stride()).to_object(py)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[getter]
|
#[getter]
|
||||||
|
@ -527,7 +527,7 @@ impl PyTensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_indexer(
|
fn extract_indexer(
|
||||||
py_indexer: &PyAny,
|
py_indexer: &Bound<PyAny>,
|
||||||
current_dim: usize,
|
current_dim: usize,
|
||||||
dims: &[usize],
|
dims: &[usize],
|
||||||
index_argument_count: usize,
|
index_argument_count: usize,
|
||||||
|
@ -567,7 +567,7 @@ impl PyTensor {
|
||||||
),
|
),
|
||||||
current_dim + 1,
|
current_dim + 1,
|
||||||
))
|
))
|
||||||
} else if py_indexer.is_ellipsis() {
|
} else if py_indexer.is(&py_indexer.py().Ellipsis()) {
|
||||||
// Handle '...' e.g. tensor[..., 0]
|
// Handle '...' e.g. tensor[..., 0]
|
||||||
if current_dim > 0 {
|
if current_dim > 0 {
|
||||||
return Err(PyTypeError::new_err(
|
return Err(PyTypeError::new_err(
|
||||||
|
@ -586,7 +586,7 @@ impl PyTensor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
|
if let Ok(tuple) = idx.downcast_bound::<pyo3::types::PyTuple>(py) {
|
||||||
let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();
|
let not_none_count: usize = tuple.iter().filter(|x| !x.is_none()).count();
|
||||||
|
|
||||||
if not_none_count > dims.len() {
|
if not_none_count > dims.len() {
|
||||||
|
@ -596,12 +596,12 @@ impl PyTensor {
|
||||||
let mut current_dim = 0;
|
let mut current_dim = 0;
|
||||||
for item in tuple.iter() {
|
for item in tuple.iter() {
|
||||||
let (indexer, new_current_dim) =
|
let (indexer, new_current_dim) =
|
||||||
extract_indexer(item, current_dim, dims, not_none_count)?;
|
extract_indexer(&item, current_dim, dims, not_none_count)?;
|
||||||
current_dim = new_current_dim;
|
current_dim = new_current_dim;
|
||||||
indexers.push(indexer);
|
indexers.push(indexer);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let (indexer, _) = extract_indexer(idx.downcast::<PyAny>(py)?, 0, dims, 1)?;
|
let (indexer, _) = extract_indexer(idx.downcast_bound::<PyAny>(py)?, 0, dims, 1)?;
|
||||||
indexers.push(indexer);
|
indexers.push(indexer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -652,7 +652,7 @@ impl PyTensor {
|
||||||
|
|
||||||
/// Add two tensors.
|
/// Add two tensors.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __add__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
|
||||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||||
self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
|
self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
|
||||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
|
@ -663,13 +663,13 @@ impl PyTensor {
|
||||||
Ok(Self(tensor))
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn __radd__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __radd__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
|
||||||
self.__add__(rhs)
|
self.__add__(rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Multiply two tensors.
|
/// Multiply two tensors.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __mul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
|
||||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||||
self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
|
self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
|
||||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
|
@ -680,13 +680,13 @@ impl PyTensor {
|
||||||
Ok(Self(tensor))
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn __rmul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __rmul__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
|
||||||
self.__mul__(rhs)
|
self.__mul__(rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Subtract two tensors.
|
/// Subtract two tensors.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __sub__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
|
||||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||||
self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
|
self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
|
||||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
|
@ -699,7 +699,7 @@ impl PyTensor {
|
||||||
|
|
||||||
/// Divide two tensors.
|
/// Divide two tensors.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
|
fn __truediv__(&self, rhs: &Bound<PyAny>) -> PyResult<Self> {
|
||||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||||
self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
|
self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
|
||||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
|
@ -711,7 +711,7 @@ impl PyTensor {
|
||||||
}
|
}
|
||||||
/// Rich-compare two tensors.
|
/// Rich-compare two tensors.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
|
fn __richcmp__(&self, rhs: &Bound<PyAny>, op: CompareOp) -> PyResult<Self> {
|
||||||
let compare = |lhs: &Tensor, rhs: &Tensor| {
|
let compare = |lhs: &Tensor, rhs: &Tensor| {
|
||||||
let t = match op {
|
let t = match op {
|
||||||
CompareOp::Eq => lhs.eq(rhs),
|
CompareOp::Eq => lhs.eq(rhs),
|
||||||
|
@ -957,7 +957,7 @@ impl PyTensor {
|
||||||
#[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
|
#[pyo3(signature = (*args, **kwargs), text_signature = "(self, *args, **kwargs)")]
|
||||||
/// Performs Tensor dtype and/or device conversion.
|
/// Performs Tensor dtype and/or device conversion.
|
||||||
/// &RETURNS&: Tensor
|
/// &RETURNS&: Tensor
|
||||||
fn to(&self, args: &PyTuple, kwargs: Option<&PyDict>) -> PyResult<Self> {
|
fn to(&self, args: &Bound<PyTuple>, kwargs: Option<&Bound<PyDict>>) -> PyResult<Self> {
|
||||||
let mut device: Option<PyDevice> = None;
|
let mut device: Option<PyDevice> = None;
|
||||||
let mut dtype: Option<PyDType> = None;
|
let mut dtype: Option<PyDType> = None;
|
||||||
let mut other: Option<PyTensor> = None;
|
let mut other: Option<PyTensor> = None;
|
||||||
|
@ -1227,7 +1227,7 @@ impl PyQTensor {
|
||||||
///Gets the shape of the tensor.
|
///Gets the shape of the tensor.
|
||||||
/// &RETURNS&: Tuple[int]
|
/// &RETURNS&: Tuple[int]
|
||||||
fn shape(&self, py: Python<'_>) -> PyObject {
|
fn shape(&self, py: Python<'_>) -> PyObject {
|
||||||
PyTuple::new(py, self.0.shape().dims()).to_object(py)
|
PyTuple::new_bound(py, self.0.shape().dims()).to_object(py)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn __repr__(&self) -> String {
|
fn __repr__(&self) -> String {
|
||||||
|
@ -1265,7 +1265,7 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(key, value)| (key, PyTensor(value).into_py(py)))
|
.map(|(key, value)| (key, PyTensor(value).into_py(py)))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
Ok(res.into_py_dict(py).to_object(py))
|
Ok(res.into_py_dict_bound(py).to_object(py))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
|
@ -1303,7 +1303,7 @@ fn load_ggml(
|
||||||
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
|
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
|
||||||
.collect::<::candle::Result<Vec<_>>>()
|
.collect::<::candle::Result<Vec<_>>>()
|
||||||
.map_err(wrap_err)?;
|
.map_err(wrap_err)?;
|
||||||
let tensors = tensors.into_py_dict(py).to_object(py);
|
let tensors = tensors.into_py_dict_bound(py).to_object(py);
|
||||||
let hparams = [
|
let hparams = [
|
||||||
("n_vocab", ggml.hparams.n_vocab),
|
("n_vocab", ggml.hparams.n_vocab),
|
||||||
("n_embd", ggml.hparams.n_embd),
|
("n_embd", ggml.hparams.n_embd),
|
||||||
|
@ -1313,7 +1313,7 @@ fn load_ggml(
|
||||||
("n_rot", ggml.hparams.n_rot),
|
("n_rot", ggml.hparams.n_rot),
|
||||||
("ftype", ggml.hparams.ftype),
|
("ftype", ggml.hparams.ftype),
|
||||||
];
|
];
|
||||||
let hparams = hparams.into_py_dict(py).to_object(py);
|
let hparams = hparams.into_py_dict_bound(py).to_object(py);
|
||||||
let vocab = ggml
|
let vocab = ggml
|
||||||
.vocab
|
.vocab
|
||||||
.token_score_pairs
|
.token_score_pairs
|
||||||
|
@ -1351,7 +1351,7 @@ fn load_gguf(
|
||||||
gguf_file::Value::Bool(x) => x.into_py(py),
|
gguf_file::Value::Bool(x) => x.into_py(py),
|
||||||
gguf_file::Value::String(x) => x.into_py(py),
|
gguf_file::Value::String(x) => x.into_py(py),
|
||||||
gguf_file::Value::Array(x) => {
|
gguf_file::Value::Array(x) => {
|
||||||
let list = pyo3::types::PyList::empty(py);
|
let list = pyo3::types::PyList::empty_bound(py);
|
||||||
for elem in x.iter() {
|
for elem in x.iter() {
|
||||||
list.append(gguf_value_to_pyobject(elem, py)?)?;
|
list.append(gguf_value_to_pyobject(elem, py)?)?;
|
||||||
}
|
}
|
||||||
|
@ -1371,13 +1371,13 @@ fn load_gguf(
|
||||||
})
|
})
|
||||||
.collect::<::candle::Result<Vec<_>>>()
|
.collect::<::candle::Result<Vec<_>>>()
|
||||||
.map_err(wrap_err)?;
|
.map_err(wrap_err)?;
|
||||||
let tensors = tensors.into_py_dict(py).to_object(py);
|
let tensors = tensors.into_py_dict_bound(py).to_object(py);
|
||||||
let metadata = gguf
|
let metadata = gguf
|
||||||
.metadata
|
.metadata
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
|
.map(|(key, value)| Ok((key, gguf_value_to_pyobject(value, py)?)))
|
||||||
.collect::<PyResult<Vec<_>>>()?
|
.collect::<PyResult<Vec<_>>>()?
|
||||||
.into_py_dict(py)
|
.into_py_dict_bound(py)
|
||||||
.to_object(py);
|
.to_object(py);
|
||||||
Ok((tensors, metadata))
|
Ok((tensors, metadata))
|
||||||
}
|
}
|
||||||
|
@ -1390,7 +1390,7 @@ fn load_gguf(
|
||||||
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
|
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
|
||||||
use ::candle::quantized::gguf_file;
|
use ::candle::quantized::gguf_file;
|
||||||
|
|
||||||
fn pyobject_to_gguf_value(v: &PyAny, py: Python<'_>) -> PyResult<gguf_file::Value> {
|
fn pyobject_to_gguf_value(v: &Bound<PyAny>, py: Python<'_>) -> PyResult<gguf_file::Value> {
|
||||||
let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
|
let v: gguf_file::Value = if let Ok(x) = v.extract::<u8>() {
|
||||||
gguf_file::Value::U8(x)
|
gguf_file::Value::U8(x)
|
||||||
} else if let Ok(x) = v.extract::<i8>() {
|
} else if let Ok(x) = v.extract::<i8>() {
|
||||||
|
@ -1418,7 +1418,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
|
||||||
} else if let Ok(x) = v.extract::<Vec<PyObject>>() {
|
} else if let Ok(x) = v.extract::<Vec<PyObject>>() {
|
||||||
let x = x
|
let x = x
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|f| pyobject_to_gguf_value(f.as_ref(py), py))
|
.map(|f| pyobject_to_gguf_value(f.bind(py), py))
|
||||||
.collect::<PyResult<Vec<_>>>()?;
|
.collect::<PyResult<Vec<_>>>()?;
|
||||||
gguf_file::Value::Array(x)
|
gguf_file::Value::Array(x)
|
||||||
} else {
|
} else {
|
||||||
|
@ -1450,7 +1450,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
|
||||||
Ok((
|
Ok((
|
||||||
key.extract::<String>()
|
key.extract::<String>()
|
||||||
.map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
|
.map_err(|_| PyErr::new::<PyValueError, _>("keys must be strings"))?,
|
||||||
pyobject_to_gguf_value(value, py)?,
|
pyobject_to_gguf_value(&value.as_borrowed(), py)?,
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.collect::<PyResult<Vec<_>>>()?;
|
.collect::<PyResult<Vec<_>>>()?;
|
||||||
|
@ -1498,7 +1498,7 @@ fn get_num_threads() -> usize {
|
||||||
::candle::utils::get_num_threads()
|
::candle::utils::get_num_threads()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle_utils(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
|
m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
|
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
|
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
|
||||||
|
@ -1579,7 +1579,7 @@ fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||||
Ok(PyTensor(s))
|
Ok(PyTensor(s))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
m.add_function(wrap_pyfunction!(silu, m)?)?;
|
m.add_function(wrap_pyfunction!(silu, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(softmax, m)?)?;
|
m.add_function(wrap_pyfunction!(softmax, m)?)?;
|
||||||
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
|
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
|
||||||
|
@ -1591,7 +1591,7 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "onnx")]
|
#[cfg(feature = "onnx")]
|
||||||
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
|
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
|
||||||
m.add_class::<PyONNXModel>()?;
|
m.add_class::<PyONNXModel>()?;
|
||||||
m.add_class::<PyONNXTensorDescriptor>()?;
|
m.add_class::<PyONNXTensorDescriptor>()?;
|
||||||
|
@ -1599,18 +1599,18 @@ fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||||
let utils = PyModule::new(py, "utils")?;
|
let utils = PyModule::new_bound(py, "utils")?;
|
||||||
candle_utils(py, utils)?;
|
candle_utils(py, &utils)?;
|
||||||
m.add_submodule(utils)?;
|
m.add_submodule(&utils)?;
|
||||||
let nn = PyModule::new(py, "functional")?;
|
let nn = PyModule::new_bound(py, "functional")?;
|
||||||
candle_functional_m(py, nn)?;
|
candle_functional_m(py, &nn)?;
|
||||||
m.add_submodule(nn)?;
|
m.add_submodule(&nn)?;
|
||||||
#[cfg(feature = "onnx")]
|
#[cfg(feature = "onnx")]
|
||||||
{
|
{
|
||||||
let onnx = PyModule::new(py, "onnx")?;
|
let onnx = PyModule::new_bound(py, "onnx")?;
|
||||||
candle_onnx_m(py, onnx)?;
|
candle_onnx_m(py, &onnx)?;
|
||||||
m.add_submodule(onnx)?;
|
m.add_submodule(&onnx)?;
|
||||||
}
|
}
|
||||||
m.add_class::<PyTensor>()?;
|
m.add_class::<PyTensor>()?;
|
||||||
m.add_class::<PyQTensor>()?;
|
m.add_class::<PyQTensor>()?;
|
||||||
|
|
|
@ -39,7 +39,7 @@ impl PyONNXTensorDescriptor {
|
||||||
/// The shape of the tensor.
|
/// The shape of the tensor.
|
||||||
/// &RETURNS&: Tuple[Union[int,str,Any]]
|
/// &RETURNS&: Tuple[Union[int,str,Any]]
|
||||||
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
|
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
|
||||||
let shape = PyList::empty(py);
|
let shape = PyList::empty_bound(py);
|
||||||
if let Some(d) = &self.0.shape {
|
if let Some(d) = &self.0.shape {
|
||||||
for dim in d.dim.iter() {
|
for dim in d.dim.iter() {
|
||||||
if let Some(value) = &dim.value {
|
if let Some(value) = &dim.value {
|
||||||
|
|
|
@ -206,6 +206,8 @@ def write(module, directory, origin, check=False):
|
||||||
if check:
|
if check:
|
||||||
with open(filename, "r") as f:
|
with open(filename, "r") as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
print("generated content")
|
||||||
|
print(pyi_content)
|
||||||
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
|
assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`"
|
||||||
else:
|
else:
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
|
@ -229,6 +231,8 @@ def write(module, directory, origin, check=False):
|
||||||
if check:
|
if check:
|
||||||
with open(filename, "r") as f:
|
with open(filename, "r") as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
print("generated content")
|
||||||
|
print(py_content)
|
||||||
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
|
assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`"
|
||||||
else:
|
else:
|
||||||
with open(filename, "w") as f:
|
with open(filename, "w") as f:
|
||||||
|
|
Loading…
Reference in New Issue