diff --git a/candle-pyo3/.gitignore b/candle-pyo3/.gitignore new file mode 100644 index 00000000..68bc17f9 --- /dev/null +++ b/candle-pyo3/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 1d431cfb..c96681bd 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -12,7 +12,6 @@ readme = "README.md" [lib] name = "candle" crate-type = ["cdylib"] -doc = false [dependencies] candle = { path = "../candle-core", version = "0.2.2", package = "candle-core" } diff --git a/candle-pyo3/README.md b/candle-pyo3/README.md index 07dff468..be6d4f68 100644 --- a/candle-pyo3/README.md +++ b/candle-pyo3/README.md @@ -1,7 +1,26 @@ +## Installation + From the `candle-pyo3` directory, enable a virtual env where you will want the candle package to be installed then run. ```bash -maturin develop +maturin develop -r python test.py ``` + +## Generating Stub Files for Type Hinting + +For type hinting support, the `candle-pyo3` package requires `*.pyi` files. You can automatically generate these files using the `stub.py` script. + +### Steps: +1. Install the package using `maturin`. +2. Generate the stub files by running: + ``` + python stub.py + ``` + +### Validation: +To ensure that the stub files match the current implementation, execute: +``` +python stub.py --check +``` diff --git a/candle-pyo3/py_src/candle/__init__.py b/candle-pyo3/py_src/candle/__init__.py new file mode 100644 index 00000000..49c96122 --- /dev/null +++ b/candle-pyo3/py_src/candle/__init__.py @@ -0,0 +1 @@ +from .candle import * \ No newline at end of file diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi new file mode 100644 index 00000000..c21e6738 --- /dev/null +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -0,0 +1,248 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device + +class bf16(DType): + pass + +@staticmethod +def cat(tensors: List[Tensor], dim: int): + """ + Concatenate the tensors across one axis. + """ + pass + +class f16(DType): + pass + +class f32(DType): + pass + +class f64(DType): + pass + +class i64(DType): + pass + +@staticmethod +def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None): + """ """ + pass + +@staticmethod +def rand(shape: Sequence[int], device: Optional[Device] = None): + """ + Creates a new tensor with random values. + """ + pass + +@staticmethod +def randn(shape: Sequence[int], device: Optional[Device] = None): + """ """ + pass + +@staticmethod +def stack(tensors: List[Tensor], dim: int): + """ + Stack the tensors along a new axis. + """ + pass + +@staticmethod +def tensor(data: _ArrayLike): + """ + Creates a new tensor from a Python value. The value can be a scalar or array-like object. 
+ """ + pass + +class u32(DType): + pass + +class u8(DType): + pass + +@staticmethod +def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None): + """ """ + pass + +class DType: + pass + +class QTensor: + def dequantize(self): + """ """ + pass + @property + def ggml_dtype(self): + """ """ + pass + def matmul_t(self, lhs): + """ """ + pass + @property + def rank(self): + """ """ + pass + @property + def shape(self): + """ """ + pass + +class Tensor: + def __init__(data: _ArrayLike): + pass + def argmax_keepdim(self, dim): + """ """ + pass + def argmin_keepdim(self, dim): + """ """ + pass + def broadcast_add(self, rhs): + """ """ + pass + def broadcast_as(self, shape): + """ """ + pass + def broadcast_div(self, rhs): + """ """ + pass + def broadcast_left(self, shape): + """ """ + pass + def broadcast_mul(self, rhs): + """ """ + pass + def broadcast_sub(self, rhs): + """ """ + pass + def contiguous(self): + """ """ + pass + def copy(self): + """ """ + pass + def cos(self): + """ """ + pass + def detach(self): + """ """ + pass + @property + def device(self): + """ """ + pass + @property + def dtype(self): + """ """ + pass + def exp(self): + """ """ + pass + def flatten_all(self): + """ """ + pass + def flatten_from(self, dim): + """ """ + pass + def flatten_to(self, dim): + """ """ + pass + def get(self, index): + """ """ + pass + def index_select(self, rhs, dim): + """ """ + pass + def is_contiguous(self): + """ """ + pass + def is_fortran_contiguous(self): + """ """ + pass + def log(self): + """ """ + pass + def matmul(self, rhs): + """ """ + pass + def max_keepdim(self, dim): + """ """ + pass + def mean_all(self): + """ """ + pass + def min_keepdim(self, dim): + """ """ + pass + def narrow(self, dim, start, len): + """ """ + pass + def powf(self, p): + """ """ + pass + def quantize(self, quantized_dtype): + """ """ + pass + @property + def rank(self): + """ """ + pass + def recip(self): + """ """ + pass + def reshape(self, shape): + """ """ + pass + @property + def shape(self): + """ + Gets the tensor shape as a Python tuple. + """ + pass + def sin(self): + """ """ + pass + def sqr(self): + """ """ + pass + def sqrt(self): + """ """ + pass + def squeeze(self, dim): + """ """ + pass + @property + def stride(self): + """ """ + pass + def sum_all(self): + """ """ + pass + def sum_keepdim(self, dims): + """ """ + pass + def t(self): + """ """ + pass + def to_device(self, device): + """ """ + pass + def to_dtype(self, dtype): + """ """ + pass + def transpose(self, dim1, dim2): + """ """ + pass + def unsqueeze(self, dim): + """ """ + pass + def values(self): + """ + Gets the tensor's data as a Python scalar or array-like object. + """ + pass + def where_cond(self, on_true, on_false): + """ """ + pass diff --git a/candle-pyo3/py_src/candle/nn/__init__.py b/candle-pyo3/py_src/candle/nn/__init__.py new file mode 100644 index 00000000..b8c5cfb7 --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/__init__.py @@ -0,0 +1,5 @@ +# Generated content DO NOT EDIT +from .. 
import nn + +silu = nn.silu +softmax = nn.softmax diff --git a/candle-pyo3/py_src/candle/nn/__init__.pyi b/candle-pyo3/py_src/candle/nn/__init__.pyi new file mode 100644 index 00000000..821cd052 --- /dev/null +++ b/candle-pyo3/py_src/candle/nn/__init__.pyi @@ -0,0 +1,19 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device +from candle import Tensor, DType + +@staticmethod +def silu(tensor: Tensor): + """ + Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. + """ + pass + +@staticmethod +def softmax(tensor: Tensor, dim: int): + """ + Applies the Softmax function to a given tensor. + """ + pass diff --git a/candle-pyo3/py_src/candle/typing/__init__.py b/candle-pyo3/py_src/candle/typing/__init__.py new file mode 100644 index 00000000..ea85d2a3 --- /dev/null +++ b/candle-pyo3/py_src/candle/typing/__init__.py @@ -0,0 +1,16 @@ +from typing import TypeVar, Union, Sequence + +_T = TypeVar("_T") + +_ArrayLike = Union[ + _T, + Sequence[_T], + Sequence[Sequence[_T]], + Sequence[Sequence[Sequence[_T]]], + Sequence[Sequence[Sequence[Sequence[_T]]]], +] + +CPU:str = "cpu" +CUDA:str = "cuda" + +Device = TypeVar("Device", CPU, CUDA) \ No newline at end of file diff --git a/candle-pyo3/py_src/candle/utils/__init__.py b/candle-pyo3/py_src/candle/utils/__init__.py new file mode 100644 index 00000000..2ead6d84 --- /dev/null +++ b/candle-pyo3/py_src/candle/utils/__init__.py @@ -0,0 +1,11 @@ +# Generated content DO NOT EDIT +from .. import utils + +cuda_is_available = utils.cuda_is_available +get_num_threads = utils.get_num_threads +has_accelerate = utils.has_accelerate +has_mkl = utils.has_mkl +load_ggml = utils.load_ggml +load_gguf = utils.load_gguf +load_safetensors = utils.load_safetensors +save_safetensors = utils.save_safetensors diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi new file mode 100644 index 00000000..7a0a5231 --- /dev/null +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -0,0 +1,63 @@ +# Generated content DO NOT EDIT +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +from candle.typing import _ArrayLike, Device +from candle import Tensor, DType + +@staticmethod +def cuda_is_available(): + """ + Returns true if the 'cuda' backend is available. + """ + pass + +@staticmethod +def get_num_threads(): + """ + Returns the number of threads used by the candle. + """ + pass + +@staticmethod +def has_accelerate(): + """ + Returns true if candle was compiled with 'accelerate' support. + """ + pass + +@staticmethod +def has_mkl(): + """ + Returns true if candle was compiled with MKL support. + """ + pass + +@staticmethod +def load_ggml(path: Union[str, PathLike]): + """ + Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, + a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. + """ + pass + +@staticmethod +def load_gguf(path: Union[str, PathLike]): + """ + Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, + and the second maps metadata keys to metadata values. + """ + pass + +@staticmethod +def load_safetensors(path: Union[str, PathLike]): + """ + Loads a safetensors file. Returns a dictionary mapping tensor names to tensors. 
+ """ + pass + +@staticmethod +def save_safetensors(path: Union[str, PathLike], tensors: Dict[str, Tensor]): + """ + Saves a dictionary of tensors to a safetensors file. + """ + pass diff --git a/candle-pyo3/pyproject.toml b/candle-pyo3/pyproject.toml new file mode 100644 index 00000000..b4e372d7 --- /dev/null +++ b/candle-pyo3/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = 'candle-pyo3' +requires-python = '>=3.7' +authors = [ + {name = 'Laurent Mazare', email = ''}, +] + +dynamic = [ + 'description', + 'license', + 'readme', +] + +[project.urls] +Homepage = 'https://github.com/huggingface/candle' +Source = 'https://github.com/huggingface/candle' + +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[tool.maturin] +python-source = "py_src" +module-name = "candle.candle" +bindings = 'pyo3' +features = ["pyo3/extension-module"] + +[tool.black] +line-length = 119 +target-version = ['py35'] \ No newline at end of file diff --git a/candle-pyo3/quant-llama.py b/candle-pyo3/quant-llama.py index 0f7a51c6..020d525d 100644 --- a/candle-pyo3/quant-llama.py +++ b/candle-pyo3/quant-llama.py @@ -1,6 +1,7 @@ # This example shows how the candle Python api can be used to replicate llama.cpp. import sys import candle +from candle.utils import load_ggml,load_gguf MAX_SEQ_LEN = 4096 @@ -154,7 +155,7 @@ def main(): filename = sys.argv[1] print(f"reading model file {filename}") if filename.endswith("gguf"): - all_tensors, metadata = candle.load_gguf(sys.argv[1]) + all_tensors, metadata = load_gguf(sys.argv[1]) vocab = metadata["tokenizer.ggml.tokens"] for i, v in enumerate(vocab): vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ') @@ -168,13 +169,13 @@ def main(): 'n_head_kv': metadata['llama.attention.head_count_kv'], 'n_layer': metadata['llama.block_count'], 'n_rot': metadata['llama.rope.dimension_count'], - 'rope_freq': metadata['llama.rope.freq_base'], + 'rope_freq': metadata.get('llama.rope.freq_base', 10000.), 'ftype': metadata['general.file_type'], } all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() } else: - all_tensors, hparams, vocab = candle.load_ggml(sys.argv[1]) + all_tensors, hparams, vocab = load_ggml(sys.argv[1]) print(hparams) model = QuantizedLlama(hparams, all_tensors) print("model built, starting inference") diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index eddc0fda..1df78ec6 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -197,38 +197,40 @@ trait MapDType { #[pymethods] impl PyTensor { #[new] + #[pyo3(text_signature = "(data:_ArrayLike)")] // TODO: Handle arbitrary input dtype and shape. - fn new(py: Python<'_>, vs: PyObject) -> PyResult { + /// Creates a new tensor from a Python value. The value can be a scalar or array-like object. + fn new(py: Python<'_>, data: PyObject) -> PyResult { use Device::Cpu; - let tensor = if let Ok(vs) = vs.extract::(py) { + let tensor = if let Ok(vs) = data.extract::(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::(py) { + } else if let Ok(vs) = data.extract::(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::(py) { + } else if let Ok(vs) = data.extract::(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>(py) { + } else if let Ok(vs) = data.extract::>(py) { let len = vs.len(); Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? 
- } else if let Ok(vs) = vs.extract::>(py) { + } else if let Ok(vs) = data.extract::>(py) { let len = vs.len(); Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>(py) { + } else if let Ok(vs) = data.extract::>(py) { let len = vs.len(); Tensor::from_vec(vs, len, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>>(py) { + } else if let Ok(vs) = data.extract::>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>>(py) { + } else if let Ok(vs) = data.extract::>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>>(py) { + } else if let Ok(vs) = data.extract::>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>>>(py) { + } else if let Ok(vs) = data.extract::>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>>>(py) { + } else if let Ok(vs) = data.extract::>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? - } else if let Ok(vs) = vs.extract::>>>(py) { + } else if let Ok(vs) = data.extract::>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else { - let ty = vs.as_ref(py).get_type(); + let ty = data.as_ref(py).get_type(); Err(PyTypeError::new_err(format!( "incorrect type {ty} for tensor" )))? @@ -236,7 +238,7 @@ impl PyTensor { Ok(Self(tensor)) } - /// Gets the tensor data as a Python value/array/array of array/... + /// Gets the tensor's data as a Python scalar or array-like object. fn values(&self, py: Python<'_>) -> PyResult { struct M<'a>(Python<'a>); impl<'a> MapDType for M<'a> { @@ -280,6 +282,7 @@ impl PyTensor { } #[getter] + /// Gets the tensor shape as a Python tuple. fn shape(&self, py: Python<'_>) -> PyObject { PyTuple::new(py, self.0.dims()).to_object(py) } @@ -580,8 +583,9 @@ impl PyTensor { } } -/// Concatenate the tensors across one axis. #[pyfunction] +#[pyo3(text_signature = "(tensors:List[Tensor], dim:int )")] +/// Concatenate the tensors across one axis. fn cat(tensors: Vec, dim: i64) -> PyResult { if tensors.is_empty() { return Err(PyErr::new::("empty input to cat")); @@ -593,6 +597,8 @@ fn cat(tensors: Vec, dim: i64) -> PyResult { } #[pyfunction] +#[pyo3(text_signature = "(tensors:List[Tensor], dim:int)")] +/// Stack the tensors along a new axis. fn stack(tensors: Vec, dim: usize) -> PyResult { let tensors = tensors.into_iter().map(|t| t.0).collect::>(); let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?; @@ -600,12 +606,15 @@ fn stack(tensors: Vec, dim: usize) -> PyResult { } #[pyfunction] -fn tensor(py: Python<'_>, vs: PyObject) -> PyResult { - PyTensor::new(py, vs) +#[pyo3(text_signature = "(data:_ArrayLike)")] +/// Creates a new tensor from a Python value. The value can be a scalar or array-like object. +fn tensor(py: Python<'_>, data: PyObject) -> PyResult { + PyTensor::new(py, data) } #[pyfunction] -#[pyo3(signature = (shape, *, device=None))] +#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] +/// Creates a new tensor with random values. fn rand(_py: Python<'_>, shape: PyShape, device: Option) -> PyResult { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; @@ -613,7 +622,7 @@ fn rand(_py: Python<'_>, shape: PyShape, device: Option) -> PyResult
} #[pyfunction] -#[pyo3(signature = (shape, *, device=None))] +#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")] fn randn(_py: Python<'_>
, shape: PyShape, device: Option) -> PyResult { let device = device.unwrap_or(PyDevice::Cpu).as_device()?; let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?; @@ -621,7 +630,7 @@ fn randn(_py: Python<'_>, shape: PyShape, device: Option) -> PyResult< } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None))] +#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] fn ones( py: Python<'_>, shape: PyShape, @@ -638,7 +647,7 @@ fn ones( } #[pyfunction] -#[pyo3(signature = (shape, *, dtype=None, device=None))] +#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")] fn zeros( py: Python<'_>, shape: PyShape, @@ -704,6 +713,8 @@ impl PyQTensor { } #[pyfunction] +#[pyo3(text_signature = "(path:Union[str,PathLike])")] +/// Loads a safetensors file. Returns a dictionary mapping tensor names to tensors. fn load_safetensors(path: &str, py: Python<'_>) -> PyResult { let res = ::candle::safetensors::load(path, &Device::Cpu).map_err(wrap_err)?; let res = res @@ -714,6 +725,8 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult { } #[pyfunction] +#[pyo3(text_signature = "(path:Union[str,PathLike], tensors:Dict[str,Tensor])")] +/// Saves a dictionary of tensors to a safetensors file. fn save_safetensors( path: &str, tensors: std::collections::HashMap, @@ -726,6 +739,9 @@ fn save_safetensors( } #[pyfunction] +#[pyo3(text_signature = "(path:Union[str,PathLike])")] +/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, +/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> { let mut file = std::fs::File::open(path)?; let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?; @@ -757,6 +773,9 @@ fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObje } #[pyfunction] +#[pyo3(text_signature = "(path:Union[str,PathLike])")] +/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, +/// and the second maps metadata keys to metadata values. fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { use ::candle::quantized::gguf_file; fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult { @@ -806,21 +825,25 @@ fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> { } #[pyfunction] +/// Returns true if the 'cuda' backend is available. fn cuda_is_available() -> bool { ::candle::utils::cuda_is_available() } #[pyfunction] +/// Returns true if candle was compiled with 'accelerate' support. fn has_accelerate() -> bool { ::candle::utils::has_accelerate() } #[pyfunction] +/// Returns true if candle was compiled with MKL support. fn has_mkl() -> bool { ::candle::utils::has_mkl() } #[pyfunction] +/// Returns the number of threads used by the candle. 
fn get_num_threads() -> usize { ::candle::utils::get_num_threads() } @@ -830,19 +853,27 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(get_num_threads, m)?)?; m.add_function(wrap_pyfunction!(has_accelerate, m)?)?; m.add_function(wrap_pyfunction!(has_mkl, m)?)?; + m.add_function(wrap_pyfunction!(load_ggml, m)?)?; + m.add_function(wrap_pyfunction!(load_gguf, m)?)?; + m.add_function(wrap_pyfunction!(load_safetensors, m)?)?; + m.add_function(wrap_pyfunction!(save_safetensors, m)?)?; Ok(()) } #[pyfunction] -fn softmax(t: PyTensor, dim: i64) -> PyResult { - let dim = actual_dim(&t, dim).map_err(wrap_err)?; - let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?; +#[pyo3(text_signature = "(tensor:Tensor, dim:int)")] +/// Applies the Softmax function to a given tensor. +fn softmax(tensor: PyTensor, dim: i64) -> PyResult { + let dim = actual_dim(&tensor, dim).map_err(wrap_err)?; + let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_err)?; Ok(PyTensor(sm)) } #[pyfunction] -fn silu(t: PyTensor) -> PyResult { - let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?; +#[pyo3(text_signature = "(tensor:Tensor)")] +/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor. +fn silu(tensor: PyTensor) -> PyResult { + let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_err)?; Ok(PyTensor(s)) } @@ -871,14 +902,10 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add("f32", PyDType(DType::F32))?; m.add("f64", PyDType(DType::F64))?; m.add_function(wrap_pyfunction!(cat, m)?)?; - m.add_function(wrap_pyfunction!(load_ggml, m)?)?; - m.add_function(wrap_pyfunction!(load_gguf, m)?)?; - m.add_function(wrap_pyfunction!(load_safetensors, m)?)?; m.add_function(wrap_pyfunction!(ones, m)?)?; m.add_function(wrap_pyfunction!(rand, m)?)?; m.add_function(wrap_pyfunction!(randn, m)?)?; m.add_function(wrap_pyfunction!(tensor, m)?)?; - m.add_function(wrap_pyfunction!(save_safetensors, m)?)?; m.add_function(wrap_pyfunction!(stack, m)?)?; m.add_function(wrap_pyfunction!(zeros, m)?)?; Ok(()) diff --git a/candle-pyo3/stub.py b/candle-pyo3/stub.py new file mode 100644 index 00000000..b5b9256f --- /dev/null +++ b/candle-pyo3/stub.py @@ -0,0 +1,217 @@ +#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py +import argparse +import inspect +import os +from typing import Optional +import black +from pathlib import Path + + +INDENT = " " * 4 +GENERATED_COMMENT = "# Generated content DO NOT EDIT\n" +TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence +from os import PathLike +""" +CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n" +CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType\n" + + + +def do_indent(text: Optional[str], indent: str): + if text is None: + return "" + return text.replace("\n", f"\n{indent}") + + +def function(obj, indent:str, text_signature:str=None): + if text_signature is None: + text_signature = obj.__text_signature__ + + text_signature = text_signature.replace("$self", "self").lstrip().rstrip() + string = "" + string += f"{indent}def {obj.__name__}{text_signature}:\n" + indent += INDENT + string += f'{indent}"""\n' + string += f"{indent}{do_indent(obj.__doc__, indent)}\n" + string += f'{indent}"""\n' + string += f"{indent}pass\n" + string += "\n" + string += "\n" + return string + + +def member_sort(member): + if inspect.isclass(member): + value = 10 + len(inspect.getmro(member)) + else: + value = 1 + return value + + 
+def fn_predicate(obj): + value = inspect.ismethoddescriptor(obj) or inspect.isbuiltin(obj) + if value: + return obj.__text_signature__ and not obj.__name__.startswith("_") + if inspect.isgetsetdescriptor(obj): + return not obj.__name__.startswith("_") + return False + + +def get_module_members(module): + members = [ + member + for name, member in inspect.getmembers(module) + if not name.startswith("_") and not inspect.ismodule(member) + ] + members.sort(key=member_sort) + return members + + +def pyi_file(obj, indent=""): + string = "" + if inspect.ismodule(obj): + string += GENERATED_COMMENT + string += TYPING + string += CANDLE_SPECIFIC_TYPING + if obj.__name__ != "candle.candle": + string += CANDLE_TENSOR_IMPORTS + members = get_module_members(obj) + for member in members: + string += pyi_file(member, indent) + + elif inspect.isclass(obj): + indent += INDENT + mro = inspect.getmro(obj) + if len(mro) > 2: + inherit = f"({mro[1].__name__})" + else: + inherit = "" + string += f"class {obj.__name__}{inherit}:\n" + + body = "" + if obj.__doc__: + body += f'{indent}"""\n{indent}{do_indent(obj.__doc__, indent)}\n{indent}"""\n' + + fns = inspect.getmembers(obj, fn_predicate) + + # Init + if obj.__text_signature__: + body += f"{indent}def __init__{obj.__text_signature__}:\n" + body += f"{indent+INDENT}pass\n" + body += "\n" + + for (name, fn) in fns: + body += pyi_file(fn, indent=indent) + + if not body: + body += f"{indent}pass\n" + + string += body + string += "\n\n" + + elif inspect.isbuiltin(obj): + string += f"{indent}@staticmethod\n" + string += function(obj, indent) + + elif inspect.ismethoddescriptor(obj): + string += function(obj, indent) + + elif inspect.isgetsetdescriptor(obj): + # TODO it would be interesting to add the setter maybe? + string += f"{indent}@property\n" + string += function(obj, indent, text_signature="(self)") + + elif obj.__class__.__name__ == "DType": + string += f"class {str(obj).lower()}(DType):\n" + string += f"{indent+INDENT}pass\n" + else: + raise Exception(f"Object {obj} is not supported") + return string + + +def py_file(module, origin): + members = get_module_members(module) + + string = GENERATED_COMMENT + string += f"from .. 
import {origin}\n" + string += "\n" + for member in members: + if hasattr(member, "__name__"): + name = member.__name__ + else: + name = str(member) + string += f"{name} = {origin}.{name}\n" + return string + + +def do_black(content, is_pyi): + mode = black.Mode( + target_versions={black.TargetVersion.PY35}, + line_length=119, + is_pyi=is_pyi, + string_normalization=True, + experimental_string_processing=False, + ) + try: + return black.format_file_contents(content, fast=True, mode=mode) + except black.NothingChanged: + return content + + +def write(module, directory, origin, check=False): + submodules = [(name, member) for name, member in inspect.getmembers(module) if inspect.ismodule(member)] + + filename = os.path.join(directory, "__init__.pyi") + pyi_content = pyi_file(module) + pyi_content = do_black(pyi_content, is_pyi=True) + os.makedirs(directory, exist_ok=True) + if check: + with open(filename, "r") as f: + data = f.read() + assert data == pyi_content, f"The content of {filename} seems outdated, please run `python stub.py`" + else: + with open(filename, "w") as f: + f.write(pyi_content) + + filename = os.path.join(directory, "__init__.py") + py_content = py_file(module, origin) + py_content = do_black(py_content, is_pyi=False) + os.makedirs(directory, exist_ok=True) + + is_auto = False + if not os.path.exists(filename): + is_auto = True + else: + with open(filename, "r") as f: + line = f.readline() + if line == GENERATED_COMMENT: + is_auto = True + + if is_auto: + if check: + with open(filename, "r") as f: + data = f.read() + assert data == py_content, f"The content of {filename} seems outdated, please run `python stub.py`" + else: + with open(filename, "w") as f: + f.write(py_content) + + for name, submodule in submodules: + write(submodule, os.path.join(directory, name), f"{name}", check=check) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--check", action="store_true") + + args = parser.parse_args() + + #Enable execution from the candle and candle-pyo3 directories + cwd = Path.cwd() + directory = "py_src/candle/" + if cwd.name != "candle-pyo3": + directory = f"candle-pyo3/{directory}" + + import candle + + write(candle.candle, directory, "candle", check=args.check) diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index 7f24b49d..c78ffc41 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,4 +1,5 @@ import candle +from candle import Tensor, QTensor t = candle.Tensor(42.0) print(t) @@ -9,7 +10,7 @@ t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6]) print(t) print(t+t) -t = t.reshape([2, 4]) +t:Tensor = t.reshape([2, 4]) print(t.matmul(t.t())) print(t.to_dtype(candle.u8)) @@ -20,7 +21,7 @@ print(t) print(t.dtype) t = candle.randn((16, 256)) -quant_t = t.quantize("q6k") -dequant_t = quant_t.dequantize() -diff2 = (t - dequant_t).sqr() +quant_t:QTensor = t.quantize("q6k") +dequant_t:Tensor = quant_t.dequantize() +diff2:Tensor = (t - dequant_t).sqr() print(diff2.mean_all())
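For context, here is a minimal, untested sketch of how the pieces added in this PR fit together from Python once the extension is built with `maturin develop -r`. It only uses calls that appear in `test.py` and the generated stubs above; the `example.safetensors` filename is an illustrative placeholder, not something from the PR.

```python
# Minimal sketch exercising the typed API described by the generated *.pyi stubs.
import candle
from candle import Tensor, QTensor
from candle.nn import softmax
from candle.utils import cuda_is_available, save_safetensors, load_safetensors

print("cuda available:", cuda_is_available())

# Tensor creation and a few ops; the annotations resolve against __init__.pyi.
t: Tensor = candle.randn((16, 256))
print(t.shape, t.dtype, t.device)
print(t.matmul(t.t()).shape)
print(softmax(t, 1).sum_all())

# Quantization round-trip, mirroring test.py (same shape, since k-quants work on 256-element blocks).
q: QTensor = t.quantize("q6k")
err: Tensor = (t - q.dequantize()).sqr().mean_all()
print(err)

# Safetensors round-trip; "example.safetensors" is an arbitrary path used for illustration.
save_safetensors("example.safetensors", {"t": t})
print(load_safetensors("example.safetensors").keys())
```

This mirrors what `test.py` already exercises, with explicit annotations that the new stub files make meaningful to a type checker or IDE.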