Generate `*.pyi` stubs for PyO3 wrapper (#870)

* Begin to generate typehints.

* generate correct stubs

* Correctly include stubs

* Add comments and typehints to static functions

* ensure candle-pyo3 directory

* Make `llama.rope.freq_base` optional

* `fmt`
This commit is contained in:
Lukas Kreussel 2023-09-16 18:23:38 +02:00 committed by GitHub
parent 7cafca835a
commit 8658df3485
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 857 additions and 40 deletions

160
candle-pyo3/.gitignore vendored Normal file
View File

@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

View File

@ -12,7 +12,6 @@ readme = "README.md"
[lib]
name = "candle"
crate-type = ["cdylib"]
doc = false
[dependencies]
candle = { path = "../candle-core", version = "0.2.2", package = "candle-core" }

View File

@ -1,7 +1,26 @@
## Installation
From the `candle-pyo3` directory, activate the virtual environment in which you want
the candle package to be installed, then run:
```bash
maturin develop
maturin develop -r
python test.py
```
## Generating Stub Files for Type Hinting
For type hinting support, the `candle-pyo3` package requires `*.pyi` files. You can automatically generate these files using the `stub.py` script.
### Steps:
1. Install the package using `maturin`.
2. Generate the stub files by running:
```
python stub.py
```
### Validation:
To ensure that the stub files match the current implementation, execute:
```
python stub.py --check
```

View File

@ -0,0 +1 @@
from .candle import *

View File

@ -0,0 +1,248 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
class bf16(DType):
pass
@staticmethod
def cat(tensors: List[Tensor], dim: int):
"""
Concatenate the tensors across one axis.
"""
pass
class f16(DType):
pass
class f32(DType):
pass
class f64(DType):
pass
class i64(DType):
pass
@staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
""" """
pass
@staticmethod
def rand(shape: Sequence[int], device: Optional[Device] = None):
"""
Creates a new tensor with random values.
"""
pass
@staticmethod
def randn(shape: Sequence[int], device: Optional[Device] = None):
""" """
pass
@staticmethod
def stack(tensors: List[Tensor], dim: int):
"""
Stack the tensors along a new axis.
"""
pass
@staticmethod
def tensor(data: _ArrayLike):
"""
Creates a new tensor from a Python value. The value can be a scalar or array-like object.
"""
pass
class u32(DType):
pass
class u8(DType):
pass
@staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None):
""" """
pass
class DType:
pass
class QTensor:
def dequantize(self):
""" """
pass
@property
def ggml_dtype(self):
""" """
pass
def matmul_t(self, lhs):
""" """
pass
@property
def rank(self):
""" """
pass
@property
def shape(self):
""" """
pass
class Tensor:
def __init__(data: _ArrayLike):
pass
def argmax_keepdim(self, dim):
""" """
pass
def argmin_keepdim(self, dim):
""" """
pass
def broadcast_add(self, rhs):
""" """
pass
def broadcast_as(self, shape):
""" """
pass
def broadcast_div(self, rhs):
""" """
pass
def broadcast_left(self, shape):
""" """
pass
def broadcast_mul(self, rhs):
""" """
pass
def broadcast_sub(self, rhs):
""" """
pass
def contiguous(self):
""" """
pass
def copy(self):
""" """
pass
def cos(self):
""" """
pass
def detach(self):
""" """
pass
@property
def device(self):
""" """
pass
@property
def dtype(self):
""" """
pass
def exp(self):
""" """
pass
def flatten_all(self):
""" """
pass
def flatten_from(self, dim):
""" """
pass
def flatten_to(self, dim):
""" """
pass
def get(self, index):
""" """
pass
def index_select(self, rhs, dim):
""" """
pass
def is_contiguous(self):
""" """
pass
def is_fortran_contiguous(self):
""" """
pass
def log(self):
""" """
pass
def matmul(self, rhs):
""" """
pass
def max_keepdim(self, dim):
""" """
pass
def mean_all(self):
""" """
pass
def min_keepdim(self, dim):
""" """
pass
def narrow(self, dim, start, len):
""" """
pass
def powf(self, p):
""" """
pass
def quantize(self, quantized_dtype):
""" """
pass
@property
def rank(self):
""" """
pass
def recip(self):
""" """
pass
def reshape(self, shape):
""" """
pass
@property
def shape(self):
"""
Gets the tensor shape as a Python tuple.
"""
pass
def sin(self):
""" """
pass
def sqr(self):
""" """
pass
def sqrt(self):
""" """
pass
def squeeze(self, dim):
""" """
pass
@property
def stride(self):
""" """
pass
def sum_all(self):
""" """
pass
def sum_keepdim(self, dims):
""" """
pass
def t(self):
""" """
pass
def to_device(self, device):
""" """
pass
def to_dtype(self, dtype):
""" """
pass
def transpose(self, dim1, dim2):
""" """
pass
def unsqueeze(self, dim):
""" """
pass
def values(self):
"""
Gets the tensor's data as a Python scalar or array-like object.
"""
pass
def where_cond(self, on_true, on_false):
""" """
pass

View File

@ -0,0 +1,5 @@
# Generated content DO NOT EDIT
from .. import nn
silu = nn.silu
softmax = nn.softmax

View File

@ -0,0 +1,19 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
from candle import Tensor, DType
@staticmethod
def silu(tensor: Tensor):
"""
Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
"""
pass
@staticmethod
def softmax(tensor: Tensor, dim: int):
"""
Applies the Softmax function to a given tensor.
"""
pass

View File

@ -0,0 +1,16 @@
import sys
from typing import Sequence, TypeVar, Union

if sys.version_info >= (3, 8):
    from typing import Literal

# Element type of an array-like value.
_T = TypeVar("_T")

# Either a bare element or a sequence of elements nested up to four levels deep.
_ArrayLike = Union[
    _T,
    Sequence[_T],
    Sequence[Sequence[_T]],
    Sequence[Sequence[Sequence[_T]]],
    Sequence[Sequence[Sequence[Sequence[_T]]]],
]

# Device name strings.
CPU: str = "cpu"
CUDA: str = "cuda"

# `Device` describes the accepted device names. It used to be declared as
# `TypeVar("Device", CPU, CUDA)`, but TypeVar constraints must be types: the
# strings "cpu"/"cuda" were interpreted as forward references to types that
# do not exist. A `Literal` alias expresses "one of these two strings"; on
# interpreters without `typing.Literal` (< 3.8) it degrades to plain `str`.
if sys.version_info >= (3, 8):
    Device = Literal["cpu", "cuda"]
else:
    Device = str

View File

@ -0,0 +1,11 @@
# Generated content DO NOT EDIT
from .. import utils
cuda_is_available = utils.cuda_is_available
get_num_threads = utils.get_num_threads
has_accelerate = utils.has_accelerate
has_mkl = utils.has_mkl
load_ggml = utils.load_ggml
load_gguf = utils.load_gguf
load_safetensors = utils.load_safetensors
save_safetensors = utils.save_safetensors

View File

@ -0,0 +1,63 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
from candle import Tensor, DType
@staticmethod
def cuda_is_available():
"""
Returns true if the 'cuda' backend is available.
"""
pass
@staticmethod
def get_num_threads():
"""
Returns the number of threads used by the candle.
"""
pass
@staticmethod
def has_accelerate():
"""
Returns true if candle was compiled with 'accelerate' support.
"""
pass
@staticmethod
def has_mkl():
"""
Returns true if candle was compiled with MKL support.
"""
pass
@staticmethod
def load_ggml(path: Union[str, PathLike]):
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
"""
pass
@staticmethod
def load_gguf(path: Union[str, PathLike]):
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
"""
pass
@staticmethod
def load_safetensors(path: Union[str, PathLike]):
"""
Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
"""
pass
@staticmethod
def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]):
"""
Saves a dictionary of tensors to a safetensors file.
"""
pass

View File

@ -0,0 +1,30 @@
[project]
name = 'candle-pyo3'
requires-python = '>=3.7'
authors = [
{name = 'Laurent Mazare', email = ''},
]
dynamic = [
'description',
'license',
'readme',
]
[project.urls]
Homepage = 'https://github.com/huggingface/candle'
Source = 'https://github.com/huggingface/candle'
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.maturin]
python-source = "py_src"
module-name = "candle.candle"
bindings = 'pyo3'
features = ["pyo3/extension-module"]
[tool.black]
line-length = 119
target-version = ['py35']

View File

@ -1,6 +1,7 @@
# This example shows how the candle Python api can be used to replicate llama.cpp.
import sys
import candle
from candle.utils import load_ggml,load_gguf
MAX_SEQ_LEN = 4096
@ -154,7 +155,7 @@ def main():
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
all_tensors, metadata = candle.load_gguf(sys.argv[1])
all_tensors, metadata = load_gguf(sys.argv[1])
vocab = metadata["tokenizer.ggml.tokens"]
for i, v in enumerate(vocab):
vocab[i] = '\n' if v == '<0x0A>' else v.replace('', ' ')
@ -168,13 +169,13 @@ def main():
'n_head_kv': metadata['llama.attention.head_count_kv'],
'n_layer': metadata['llama.block_count'],
'n_rot': metadata['llama.rope.dimension_count'],
'rope_freq': metadata['llama.rope.freq_base'],
'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
'ftype': metadata['general.file_type'],
}
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
else:
all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1])
all_tensors, hparams, vocab = load_ggml(sys.argv[1])
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
print("model built, starting inference")

View File

@ -197,38 +197,40 @@ trait MapDType {
#[pymethods]
impl PyTensor {
#[new]
#[pyo3(text_signature = "(data:_ArrayLike)")]
// TODO: Handle arbitrary input dtype and shape.
fn new(py: Python<'_>, vs: PyObject) -> PyResult<Self> {
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
fn new(py: Python<'_>, data: PyObject) -> PyResult<Self> {
use Device::Cpu;
let tensor = if let Ok(vs) = vs.extract::<u32>(py) {
let tensor = if let Ok(vs) = data.extract::<u32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<i64>(py) {
} else if let Ok(vs) = data.extract::<i64>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<f32>(py) {
} else if let Ok(vs) = data.extract::<f32>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<u32>>(py) {
} else if let Ok(vs) = data.extract::<Vec<u32>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<i64>>(py) {
} else if let Ok(vs) = data.extract::<Vec<i64>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
} else if let Ok(vs) = data.extract::<Vec<f32>>(py) {
let len = vs.len();
Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<Vec<u32>>>(py) {
} else if let Ok(vs) = data.extract::<Vec<Vec<u32>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<Vec<i64>>>(py) {
} else if let Ok(vs) = data.extract::<Vec<Vec<i64>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<Vec<f32>>>(py) {
} else if let Ok(vs) = data.extract::<Vec<Vec<f32>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<u32>>>>(py) {
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<u32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<i64>>>>(py) {
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<i64>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = vs.extract::<Vec<Vec<Vec<f32>>>>(py) {
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else {
let ty = vs.as_ref(py).get_type();
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
"incorrect type {ty} for tensor"
)))?
@ -236,7 +238,7 @@ impl PyTensor {
Ok(Self(tensor))
}
/// Gets the tensor data as a Python value/array/array of array/...
/// Gets the tensor's data as a Python scalar or array-like object.
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
struct M<'a>(Python<'a>);
impl<'a> MapDType for M<'a> {
@ -280,6 +282,7 @@ impl PyTensor {
}
#[getter]
/// Gets the tensor shape as a Python tuple.
fn shape(&self, py: Python<'_>) -> PyObject {
PyTuple::new(py, self.0.dims()).to_object(py)
}
@ -580,8 +583,9 @@ impl PyTensor {
}
}
/// Concatenate the tensors across one axis.
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")]
/// Concatenate the tensors across one axis.
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
if tensors.is_empty() {
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
@ -593,6 +597,8 @@ fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
}
#[pyfunction]
#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")]
/// Stack the tensors along a new axis.
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
@ -600,12 +606,15 @@ fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
}
#[pyfunction]
fn tensor(py: Python<'_>, vs: PyObject) -> PyResult<PyTensor> {
PyTensor::new(py, vs)
#[pyo3(text_signature = "(data:_ArrayLike)")]
/// Creates a new tensor from a Python value. The value can be a scalar or array-like object.
fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
PyTensor::new(py, data)
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None))]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@ -613,7 +622,7 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<P
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None))]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
@ -621,7 +630,7 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
fn ones(
py: Python<'_>,
shape: PyShape,
@ -638,7 +647,7 @@ fn ones(
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None))]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
fn zeros(
py: Python<'_>,
shape: PyShape,
@ -704,6 +713,8 @@ impl PyQTensor {
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors.
fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?;
let res = res
@ -714,6 +725,8 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")]
/// Saves a dictionary of tensors to a safetensors file.
fn save_safetensors(
path: &str,
tensors: std::collections::HashMap<String, PyTensor>,
@ -726,6 +739,9 @@ fn save_safetensors(
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
@ -757,6 +773,9 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje
}
#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
@ -806,21 +825,25 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
}
#[pyfunction]
/// Returns true if the 'cuda' backend is available.
fn cuda_is_available() -> bool {
::candle::utils::cuda_is_available()
}
#[pyfunction]
/// Returns true if candle was compiled with 'accelerate' support.
fn has_accelerate() -> bool {
::candle::utils::has_accelerate()
}
#[pyfunction]
/// Returns true if candle was compiled with MKL support.
fn has_mkl() -> bool {
::candle::utils::has_mkl()
}
#[pyfunction]
/// Returns the number of threads used by the candle.
fn get_num_threads() -> usize {
::candle::utils::get_num_threads()
}
@ -830,19 +853,27 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_num_threads, m)?)?;
m.add_function(wrap_pyfunction!(has_accelerate, m)?)?;
m.add_function(wrap_pyfunction!(has_mkl, m)?)?;
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
Ok(())
}
#[pyfunction]
fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
let dim = actual_dim(&t, dim).map_err(wrap_err)?;
let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
#[pyo3(text_signature = "(tensor:Tensor, dim:int)")]
/// Applies the Softmax function to a given tensor.
fn softmax(tensor: PyTensor, dim: i64) -> PyResult<PyTensor> {
let dim = actual_dim(&tensor, dim).map_err(wrap_err)?;
let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?;
Ok(PyTensor(sm))
}
#[pyfunction]
fn silu(t: PyTensor) -> PyResult<PyTensor> {
let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?;
Ok(PyTensor(s))
}
@ -871,14 +902,10 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add("f32", PyDType(DType::F32))?;
m.add("f64", PyDType(DType::F64))?;
m.add_function(wrap_pyfunction!(cat, m)?)?;
m.add_function(wrap_pyfunction!(load_ggml, m)?)?;
m.add_function(wrap_pyfunction!(load_gguf, m)?)?;
m.add_function(wrap_pyfunction!(load_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(ones, m)?)?;
m.add_function(wrap_pyfunction!(rand, m)?)?;
m.add_function(wrap_pyfunction!(randn, m)?)?;
m.add_function(wrap_pyfunction!(tensor, m)?)?;
m.add_function(wrap_pyfunction!(save_safetensors, m)?)?;
m.add_function(wrap_pyfunction!(stack, m)?)?;
m.add_function(wrap_pyfunction!(zeros, m)?)?;
Ok(())

217
candle-pyo3/stub.py Normal file
View File

@ -0,0 +1,217 @@
#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
import argparse
import inspect
import os
from typing import Optional
import black
from pathlib import Path
INDENT = " " * 4
GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n"
def do_indent(text: Optional[str], indent: str):
    """Insert ``indent`` after every newline of ``text``.

    The first line is left untouched (the caller is expected to indent it).
    A ``None`` text yields an empty string.
    """
    if text is not None:
        return f"\n{indent}".join(text.split("\n"))
    return ""
def function(obj, indent: str, text_signature: str = None):
    """Render the stub source of a single callable.

    The result is a ``def`` line built from ``obj.__name__`` and the text
    signature, followed by the docstring of ``obj`` and a ``pass`` body, and
    terminated by two blank lines.

    Args:
        obj: Object providing ``__name__``, ``__doc__`` and, unless
            ``text_signature`` is given, ``__text_signature__``.
        indent: Prefix placed in front of the ``def`` line; the body is
            indented one level (``INDENT``) deeper.
        text_signature: Signature to use instead of ``obj.__text_signature__``.
    """
    signature = obj.__text_signature__ if text_signature is None else text_signature
    # PyO3/CPython spell the receiver as `$self` in text signatures.
    signature = signature.replace("$self", "self").strip()

    body_indent = indent + INDENT
    lines = [
        f"{indent}def {obj.__name__}{signature}:",
        f'{body_indent}"""',
        f"{body_indent}{do_indent(obj.__doc__, body_indent)}",
        f'{body_indent}"""',
        f"{body_indent}pass",
        "",
        "",
    ]
    return "\n".join(lines) + "\n"
def member_sort(member):
    """Sort key placing non-class members first and classes after them.

    Non-classes all share the key ``1``; a class gets ``10`` plus the length
    of its MRO, so classes with a longer inheritance chain sort later.
    """
    if not inspect.isclass(member):
        return 1
    return 10 + len(inspect.getmro(member))
def fn_predicate(obj):
    """Tell whether ``obj`` is a class member that should appear in a stub.

    Method descriptors and builtins qualify when they expose a text signature
    and have a public (non-underscore) name; get/set descriptors qualify when
    their name is public. Anything else is rejected.
    """

    def is_public():
        return not obj.__name__.startswith("_")

    if inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj):
        # The signature is tested first so the name is only read when needed.
        return obj.__text_signature__ and is_public()
    return is_public() if inspect.isgetsetdescriptor(obj) else False
def get_module_members(module):
    """Return the public, non-module members of ``module`` ordered by ``member_sort``.

    Members whose name starts with an underscore and members that are
    themselves modules are left out.
    """
    public = []
    for name, member in inspect.getmembers(module):
        if name.startswith("_") or inspect.ismodule(member):
            continue
        public.append(member)
    return sorted(public, key=member_sort)
def _init_signature(text_signature):
    """Return ``text_signature`` as a parameter list whose first parameter is ``self``."""
    signature = text_signature.replace("$self", "self").strip()
    if not signature.startswith("("):
        return signature
    params = signature[1:].lstrip()
    # Character right after a leading "self"; must not continue an identifier
    # (e.g. a first parameter named "selfish" still needs `self` prepended).
    next_char = params[4:5]
    if params.startswith("self") and not (next_char.isalnum() or next_char == "_"):
        return signature
    separator = "" if params.startswith(")") else ", "
    return f"(self{separator}{params}"


def pyi_file(obj, indent=""):
    """Recursively render the ``.pyi`` stub text for ``obj``.

    Supported objects:

    * module: generated-file header and imports, then every member returned
      by ``get_module_members``;
    * class: ``class`` statement (with its first base when the MRO has more
      than two entries), docstring, an ``__init__`` built from the class text
      signature, and every member accepted by ``fn_predicate``;
    * builtin: function stub, decorated with ``@staticmethod`` only when it is
      nested inside a class;
    * method descriptor: function stub;
    * get/set descriptor: ``@property`` stub with a ``(self)`` signature;
    * object whose class is named ``DType``: empty ``DType`` subclass named
      after the lower-cased ``str(obj)``.

    Args:
        obj: The object to describe.
        indent: Indentation prefix of the current nesting level.

    Returns:
        The stub source as a string.

    Raises:
        Exception: If ``obj`` is none of the supported kinds.
    """
    string = ""
    if inspect.ismodule(obj):
        string += GENERATED_COMMENT
        string += TYPING
        string += CANDLE_SPECIFIC_TYPING
        if obj.__name__ != "candle.candle":
            string += CANDLE_TENSOR_IMPORTS
        for member in get_module_members(obj):
            string += pyi_file(member, indent)

    elif inspect.isclass(obj):
        indent += INDENT
        mro = inspect.getmro(obj)
        inherit = f"({mro[1].__name__})" if len(mro) > 2 else ""
        string += f"class {obj.__name__}{inherit}:\n"

        body = ""
        if obj.__doc__:
            body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n'

        fns = inspect.getmembers(obj, fn_predicate)

        # Init: the class text signature describes the constructor arguments
        # only, so `self` has to be added to obtain a valid method signature.
        if obj.__text_signature__:
            body += f"{indent}def __init__{_init_signature(obj.__text_signature__)}:\n"
            body += f"{indent+INDENT}pass\n"
            body += "\n"

        for _, fn in fns:
            body += pyi_file(fn, indent=indent)

        if not body:
            body += f"{indent}pass\n"

        string += body
        string += "\n\n"

    elif inspect.isbuiltin(obj):
        # A non-empty indent means we are inside a class body; `@staticmethod`
        # is only meaningful there and is rejected on module-level functions.
        if indent:
            string += f"{indent}@staticmethod\n"
        string += function(obj, indent)

    elif inspect.ismethoddescriptor(obj):
        string += function(obj, indent)

    elif inspect.isgetsetdescriptor(obj):
        # TODO it would be interesting to add the setter maybe ?
        string += f"{indent}@property\n"
        string += function(obj, indent, text_signature="(self)")

    elif obj.__class__.__name__ == "DType":
        string += f"class {str(obj).lower()}(DType):\n"
        string += f"{indent+INDENT}pass\n"

    else:
        raise Exception(f"Object {obj} is not supported")
    return string
def py_file(module, origin):
    """Render the ``__init__.py`` that re-exports the members of ``module``.

    The file starts with the generated-content marker, imports ``origin`` from
    the parent package and binds every member of ``module`` (as listed by
    ``get_module_members``) to the attribute of the same name on ``origin``.
    """
    lines = [GENERATED_COMMENT, f"from .. import {origin}\n", "\n"]
    for member in get_module_members(module):
        # Members without a `__name__` are named after their string form.
        name = member.__name__ if hasattr(member, "__name__") else str(member)
        lines.append(f"{name} = {origin}.{name}\n")
    return "".join(lines)
def do_black(content, is_pyi):
    """Format ``content`` with black and return the formatted text.

    ``is_pyi`` selects black's stub-file mode. When black reports that nothing
    changed, ``content`` is returned as is.
    """
    formatting = black.Mode(
        target_versions={black.TargetVersion.PY35},
        line_length=119,
        is_pyi=is_pyi,
        string_normalization=True,
        experimental_string_processing=False,
    )
    try:
        formatted = black.format_file_contents(content, fast=True, mode=formatting)
    except black.NothingChanged:
        formatted = content
    return formatted
def _check_or_write(filename, content, check):
    """Write ``content`` to ``filename``, or in check mode verify the file already holds it."""
    if check:
        with open(filename, "r") as f:
            data = f.read()
        # Raised explicitly instead of using `assert`, which is stripped under
        # `python -O` and would make `--check` pass unconditionally.
        if data != content:
            raise AssertionError(f"The content of {filename} seems outdated, please run `python stub.py`")
    else:
        with open(filename, "w") as f:
            f.write(content)


def write(module, directory, origin, check=False):
    """Generate, or verify, the stub files of ``module`` and of its submodules.

    For ``module`` an ``__init__.pyi`` (type stubs) and an ``__init__.py``
    (re-exports) are produced inside ``directory``. An existing ``__init__.py``
    is only touched when its first line is the generated-content marker, so
    hand-written files are preserved. Submodules are processed recursively in
    a sub-directory named after them.

    Args:
        module: The module to describe.
        directory: Target directory of the generated files.
        origin: Name under which the module is imported by the generated
            ``__init__.py``.
        check: When true nothing is written; the existing files are compared
            with the freshly generated content instead.

    Raises:
        AssertionError: In check mode, if a file differs from the generated
            content.
    """
    submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)]

    # A pure check must not modify the file system.
    if not check:
        os.makedirs(directory, exist_ok=True)

    filename = os.path.join(directory, "__init__.pyi")
    pyi_content = do_black(pyi_file(module), is_pyi=True)
    _check_or_write(filename, pyi_content, check)

    filename = os.path.join(directory, "__init__.py")
    py_content = do_black(py_file(module, origin), is_pyi=False)

    # A missing file, or one starting with the marker, is ours to (re)generate.
    is_auto = True
    if os.path.exists(filename):
        with open(filename, "r") as f:
            is_auto = f.readline() == GENERATED_COMMENT

    if is_auto:
        _check_or_write(filename, py_content, check)

    for name, submodule in submodules:
        write(submodule, os.path.join(directory, name), name, check=check)
if __name__ == "__main__":
    # Command line: `--check` compares the existing stubs instead of writing them.
    cli = argparse.ArgumentParser()
    cli.add_argument("--check", action="store_true")
    args = cli.parse_args()

    # Enable execution from the candle and candle-pyo3 directories
    directory = "py_src/candle/"
    running_in_package_dir = Path.cwd().name == "candle-pyo3"
    if not running_in_package_dir:
        directory = f"candle-pyo3/{directory}"

    # Imported here so that the module is only needed when run as a script.
    import candle

    write(candle.candle, directory, "candle", check=args.check)

View File

@ -1,4 +1,5 @@
import candle
from candle import Tensor, QTensor
t = candle.Tensor(42.0)
print(t)
@ -9,7 +10,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
print(t)
print(t+t)
t = t.reshape([2, 4])
t:Tensor = t.reshape([2, 4])
print(t.matmul(t.t()))
print(t.to_dtype(candle.u8))
@ -20,7 +21,7 @@ print(t)
print(t.dtype)
t = candle.randn((16, 256))
quant_t = t.quantize("q6k")
dequant_t = quant_t.dequantize()
diff2 = (t - dequant_t).sqr()
quant_t:QTensor = t.quantize("q6k")
dequant_t:Tensor = quant_t.dequantize()
diff2:Tensor = (t - dequant_t).sqr()
print(diff2.mean_all())