From c05c0a8213eb5518901b2aa87503e8c6b65b9d0f Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Mon, 30 Oct 2023 16:17:28 +0100 Subject: [PATCH] PyO3: Add `equal` and `__richcmp__` to `candle.Tensor` (#1099) * add `equal` to tensor * add `__richcmp__` support for tensors and scalars * typo * more typos * Add `abs` + `candle.testing` * remove duplicated `broadcast_shape_binary_op` * `candle.i16` => `candle.i64` * `tensor.nelements` -> `tensor.nelement` * Cleanup `abs` --- candle-core/src/shape.rs | 2 +- candle-pyo3/_additional_typing/__init__.py | 36 +++++++++ candle-pyo3/py_src/candle/__init__.pyi | 41 +++++++++++ candle-pyo3/py_src/candle/testing/__init__.py | 70 ++++++++++++++++++ candle-pyo3/src/lib.rs | 73 ++++++++++++++++++- candle-pyo3/tests/bindings/test_testing.py | 33 +++++++++ candle-pyo3/tests/native/test_tensor.py | 73 +++++++++++++++++++ 7 files changed, 325 insertions(+), 3 deletions(-) create mode 100644 candle-pyo3/py_src/candle/testing/__init__.py create mode 100644 candle-pyo3/tests/bindings/test_testing.py diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index ac00a979..beaa9455 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -203,7 +203,7 @@ impl Shape { /// Check whether the two shapes are compatible for broadcast, and if it is the case return the /// broadcasted shape. This is to be used for binary pointwise ops. - pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result { + pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result { let lhs = self; let lhs_dims = lhs.dims(); let rhs_dims = rhs.dims(); diff --git a/candle-pyo3/_additional_typing/__init__.py b/candle-pyo3/_additional_typing/__init__.py index 0d0eec90..7bc17ee1 100644 --- a/candle-pyo3/_additional_typing/__init__.py +++ b/candle-pyo3/_additional_typing/__init__.py @@ -53,3 +53,39 @@ class Tensor: Return a slice of a tensor. """ pass + + def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + + def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index 35b17680..37b8fe8c 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -124,16 +124,46 @@ class Tensor: Add a scalar to a tensor or two tensors together. """ pass + def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor": """ Return a slice of a tensor. """ pass + def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass + def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": """ Multiply a tensor by a scalar or one tensor by another. """ pass + def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": + """ + Compare a tensor with a scalar or one tensor with another. + """ + pass def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": """ Add a scalar to a tensor or two tensors together. @@ -159,6 +189,11 @@ class Tensor: Divide a tensor by a scalar or one tensor by another. """ pass + def abs(self) -> Tensor: + """ + Performs the `abs` operation on the tensor. + """ + pass def argmax_keepdim(self, dim: int) -> Tensor: """ Returns the indices of the maximum value(s) across the selected dimension. @@ -308,6 +343,12 @@ class Tensor: ranges from `start` to `start + len`. """ pass + @property + def nelement(self) -> int: + """ + Gets the tensor's element count. + """ + pass def powf(self, p: float) -> Tensor: """ Performs the `pow` operation on the tensor with the given exponent. diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py new file mode 100644 index 00000000..240b635f --- /dev/null +++ b/candle-pyo3/py_src/candle/testing/__init__.py @@ -0,0 +1,70 @@ +import candle +from candle import Tensor + + +_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)]) + + +def _assert_tensor_metadata( + actual: Tensor, + expected: Tensor, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + if check_device: + assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}" + + if check_dtype: + assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}" + + if check_layout: + assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}" + + if check_stride: + assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}" + + +def assert_equal( + actual: Tensor, + expected: Tensor, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + """ + Asserts that two tensors are exact equals. + """ + _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) + assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}" + + +def assert_almost_equal( + actual: Tensor, + expected: Tensor, + rtol=1e-05, + atol=1e-08, + check_device: bool = True, + check_dtype: bool = True, + check_layout: bool = True, + check_stride: bool = False, +): + """ + Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance. + + Computes: |actual - expected| ≤ atol + rtol x |expected| + """ + _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) + + # Secure against overflow of u32 and u8 tensors + if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES: + actual = actual.to(candle.i64) + expected = expected.to(candle.i64) + + diff = (actual - expected).abs() + + threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected) + + assert (diff <= threshold).sum_all().values() == actual.nelement, f"Difference between tensors was to great" diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 41c4577f..ddd58fbe 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,8 +1,11 @@ #![allow(clippy::redundant_closure_call)] use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::pyclass::CompareOp; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::os::raw::c_long; use std::sync::Arc; @@ -132,9 +135,10 @@ macro_rules! pydtype { } }; } + +pydtype!(i64, |v| v); pydtype!(u8, |v| v); pydtype!(u32, |v| v); -pydtype!(i64, |v| v); pydtype!(f16, f32::from); pydtype!(bf16, f32::from); pydtype!(f32, |v| v); @@ -317,6 +321,13 @@ impl PyTensor { PyTuple::new(py, self.0.dims()).to_object(py) } + #[getter] + /// Gets the tensor's element count. + /// &RETURNS&: int + fn nelement(&self) -> usize { + self.0.elem_count() + } + #[getter] /// Gets the tensor's strides. /// &RETURNS&: Tuple[int] @@ -353,6 +364,12 @@ impl PyTensor { self.__repr__() } + /// Performs the `abs` operation on the tensor. + /// &RETURNS&: Tensor + fn abs(&self) -> PyResult { + Ok(PyTensor(self.0.abs().map_err(wrap_err)?)) + } + /// Performs the `sin` operation on the tensor. /// &RETURNS&: Tensor fn sin(&self) -> PyResult { @@ -670,6 +687,58 @@ impl PyTensor { }; Ok(Self(tensor)) } + /// Rich-compare two tensors. + /// &RETURNS&: Tensor + fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult { + let compare = |lhs: &Tensor, rhs: &Tensor| { + let t = match op { + CompareOp::Eq => lhs.eq(rhs), + CompareOp::Ne => lhs.ne(rhs), + CompareOp::Lt => lhs.lt(rhs), + CompareOp::Le => lhs.le(rhs), + CompareOp::Gt => lhs.gt(rhs), + CompareOp::Ge => lhs.ge(rhs), + }; + Ok(PyTensor(t.map_err(wrap_err)?)) + }; + if let Ok(rhs) = rhs.extract::() { + if self.0.shape() == rhs.0.shape() { + compare(&self.0, &rhs.0) + } else { + // We broadcast manually here because `candle.cmp` does not support automatic broadcasting + let broadcast_shape = self + .0 + .shape() + .broadcast_shape_binary_op(rhs.0.shape(), "cmp") + .map_err(wrap_err)?; + let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; + let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?; + + compare(&broadcasted_lhs, &broadcasted_rhs) + } + } else if let Ok(rhs) = rhs.extract::() { + let scalar_tensor = Tensor::new(rhs, self.0.device()) + .map_err(wrap_err)? + .to_dtype(self.0.dtype()) + .map_err(wrap_err)? + .broadcast_as(self.0.shape()) + .map_err(wrap_err)?; + + compare(&self.0, &scalar_tensor) + } else { + return Err(PyTypeError::new_err("unsupported rhs for __richcmp__")); + } + } + + fn __hash__(&self) -> u64 { + // we have overridden __richcmp__ => py03 wants us to also override __hash__ + // we simply hash the address of the tensor + let mut hasher = DefaultHasher::new(); + let pointer = &self.0 as *const Tensor; + let address = pointer as usize; + address.hash(&mut hasher); + hasher.finish() + } #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")] /// Reshapes the tensor to the given shape. @@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add("u8", PyDType(DType::U8))?; m.add("u32", PyDType(DType::U32))?; - m.add("i16", PyDType(DType::I64))?; + m.add("i64", PyDType(DType::I64))?; m.add("bf16", PyDType(DType::BF16))?; m.add("f16", PyDType(DType::F16))?; m.add("f32", PyDType(DType::F32))?; diff --git a/candle-pyo3/tests/bindings/test_testing.py b/candle-pyo3/tests/bindings/test_testing.py new file mode 100644 index 00000000..db2fd3f7 --- /dev/null +++ b/candle-pyo3/tests/bindings/test_testing.py @@ -0,0 +1,33 @@ +import candle +from candle import Tensor +from candle.testing import assert_equal, assert_almost_equal +import pytest + + +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) +def test_assert_equal_asserts_correctly(dtype: candle.DType): + a = Tensor([1, 2, 3]).to(dtype) + b = Tensor([1, 2, 3]).to(dtype) + assert_equal(a, b) + + with pytest.raises(AssertionError): + assert_equal(a, b + 1) + + +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) +def test_assert_almost_equal_asserts_correctly(dtype: candle.DType): + a = Tensor([1, 2, 3]).to(dtype) + b = Tensor([1, 2, 3]).to(dtype) + assert_almost_equal(a, b) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1) + + assert_almost_equal(a, b + 1, atol=20) + assert_almost_equal(a, b + 1, rtol=20) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1, atol=0.9) + + with pytest.raises(AssertionError): + assert_almost_equal(a, b + 1, rtol=0.1) diff --git a/candle-pyo3/tests/native/test_tensor.py b/candle-pyo3/tests/native/test_tensor.py index e4cf19f1..ef44fc4c 100644 --- a/candle-pyo3/tests/native/test_tensor.py +++ b/candle-pyo3/tests/native/test_tensor.py @@ -1,6 +1,7 @@ import candle from candle import Tensor from candle.utils import cuda_is_available +from candle.testing import assert_equal import pytest @@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d(): assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]] +def assert_bool(t: Tensor, expected: bool): + assert t.shape == () + assert str(t.dtype) == str(candle.u8) + assert bool(t.values()) == expected + + +def test_tensor_supports_equality_opperations_with_scalars(): + t = Tensor(42.0) + + assert_bool(t == 42.0, True) + assert_bool(t == 43.0, False) + + assert_bool(t != 42.0, False) + assert_bool(t != 43.0, True) + + assert_bool(t > 41.0, True) + assert_bool(t > 42.0, False) + + assert_bool(t >= 41.0, True) + assert_bool(t >= 42.0, True) + + assert_bool(t < 43.0, True) + assert_bool(t < 42.0, False) + + assert_bool(t <= 43.0, True) + assert_bool(t <= 42.0, True) + + +def test_tensor_supports_equality_opperations_with_tensors(): + t = Tensor(42.0) + same = Tensor(42.0) + other = Tensor(43.0) + + assert_bool(t == same, True) + assert_bool(t == other, False) + + assert_bool(t != same, False) + assert_bool(t != other, True) + + assert_bool(t > same, False) + assert_bool(t > other, False) + + assert_bool(t >= same, True) + assert_bool(t >= other, False) + + assert_bool(t < same, False) + assert_bool(t < other, True) + + assert_bool(t <= same, True) + assert_bool(t <= other, True) + + +def test_tensor_equality_opperations_can_broadcast(): + # Create a decoder attention mask as a test case + # e.g. + # [[1,0,0] + # [1,1,0] + # [1,1,1]] + mask_cond = candle.Tensor([0, 1, 2]) + mask = mask_cond < (mask_cond + 1).reshape((3, 1)) + assert mask.shape == (3, 3) + assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8)) + + +def test_tensor_can_be_hashed(): + t = Tensor(42.0) + other = Tensor(42.0) + # Hash should represent a unique tensor + assert hash(t) != hash(other) + assert hash(t) == hash(t) + + def test_tensor_can_be_expanded_with_none(): t = candle.rand((12, 12))