PyO3: Add `equal` and `__richcmp__` to `candle.Tensor` (#1099)
* add `equal` to tensor
* add `__richcmp__` support for tensors and scalars
* typo
* more typos
* Add `abs` + `candle.testing`
* remove duplicated `broadcast_shape_binary_op`
* `candle.i16` => `candle.i64`
* `tensor.nelements` -> `tensor.nelement`
* Cleanup `abs`
This commit is contained in:
parent 969960847a
commit c05c0a8213
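As a quick orientation, this is what the new functionality looks like from Python. This is a sketch based on the tests added later in this commit and assumes the `candle` package is built from this revision:

```python
import candle
from candle import Tensor
from candle.testing import assert_equal

t = Tensor(42.0)
assert bool((t == 42.0).values())   # scalar comparison via __richcmp__
assert bool((t >= 41.0).values())

a = candle.Tensor([0, 1, 2])
mask = a < (a + 1).reshape((3, 1))  # shapes (3,) and (3, 1) broadcast to (3, 3)
assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
```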
@@ -203,7 +203,7 @@ impl Shape {
     /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
     /// broadcasted shape. This is to be used for binary pointwise ops.
-    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
+    pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.dims();
         let rhs_dims = rhs.dims();
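Making `broadcast_shape_binary_op` public lets the PyO3 layer compute the result shape before comparing tensors of different shapes. For intuition, the rule is the usual right-aligned broadcasting; the sketch below is plain Python for illustration only, not the actual Rust implementation:

```python
def broadcast_shape(lhs: tuple, rhs: tuple) -> tuple:
    # Right-align the shapes and pad the shorter one with leading 1s.
    ndims = max(len(lhs), len(rhs))
    lhs = (1,) * (ndims - len(lhs)) + lhs
    rhs = (1,) * (ndims - len(rhs)) + rhs
    out = []
    for l, r in zip(lhs, rhs):
        if l == r or r == 1:
            out.append(l)
        elif l == 1:
            out.append(r)
        else:
            raise ValueError(f"shapes {lhs} and {rhs} cannot be broadcast together")
    return tuple(out)

assert broadcast_shape((3,), (3, 1)) == (3, 3)
assert broadcast_shape((2, 1, 4), (5, 4)) == (2, 5, 4)
```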
@@ -53,3 +53,39 @@ class Tensor:
         Return a slice of a tensor.
         """
         pass
+
+    def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+
+    def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
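Note that these dunder methods are typed to return a `Tensor`, not `bool`: comparisons are element-wise and produce a `u8` mask, as the tests later in this commit check. A small illustrative sketch, assuming the built bindings:

```python
import candle
from candle import Tensor

t = Tensor([1.0, 2.0, 3.0])
eq = t == 2.0                       # element-wise; a Tensor, not a Python bool
assert eq.shape == (3,)
assert str(eq.dtype) == str(candle.u8)
```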
@@ -124,16 +124,46 @@ class Tensor:
         Add a scalar to a tensor or two tensors together.
         """
         pass
+    def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+    def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
     def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
         """
         Return a slice of a tensor.
         """
         pass
+    def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+    def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
+    def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
     def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Multiply a tensor by a scalar or one tensor by another.
         """
         pass
+    def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
+        """
+        Compare a tensor with a scalar or one tensor with another.
+        """
+        pass
     def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
         """
         Add a scalar to a tensor or two tensors together.
@@ -159,6 +189,11 @@ class Tensor:
         Divide a tensor by a scalar or one tensor by another.
         """
         pass
+    def abs(self) -> Tensor:
+        """
+        Performs the `abs` operation on the tensor.
+        """
+        pass
     def argmax_keepdim(self, dim: int) -> Tensor:
         """
         Returns the indices of the maximum value(s) across the selected dimension.
@@ -308,6 +343,12 @@ class Tensor:
         ranges from `start` to `start + len`.
         """
         pass
+    @property
+    def nelement(self) -> int:
+        """
+        Gets the tensor's element count.
+        """
+        pass
     def powf(self, p: float) -> Tensor:
         """
         Performs the `pow` operation on the tensor with the given exponent.
@@ -0,0 +1,70 @@ (new file: the `candle.testing` module)
import candle
from candle import Tensor

_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])


def _assert_tensor_metadata(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    if check_device:
        assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"

    if check_dtype:
        assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"

    if check_layout:
        assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"

    if check_stride:
        assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"


def assert_equal(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are exactly equal.
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
    assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"


def assert_almost_equal(
    actual: Tensor,
    expected: Tensor,
    rtol=1e-05,
    atol=1e-08,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are almost equal by performing an element-wise comparison with a tolerance.

    Computes: |actual - expected| ≤ atol + rtol * |expected|
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

    # Guard against wraparound when subtracting u32 and u8 tensors
    if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
        actual = actual.to(candle.i64)
        expected = expected.to(candle.i64)

    diff = (actual - expected).abs()

    threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

    assert (diff <= threshold).sum_all().values() == actual.nelement, "Difference between tensors was too great"
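Typical usage of the new helpers (illustrative only; the tolerance check follows the docstring above, |actual - expected| ≤ atol + rtol * |expected|):

```python
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal

a = Tensor([1, 2, 3]).to(candle.f32)
assert_equal(a, Tensor([1, 2, 3]).to(candle.f32))

# Passes: the element-wise difference is 1, and 1 <= atol + rtol * |expected|.
assert_almost_equal(a, a + 1, atol=1.0)
```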
@@ -1,8 +1,11 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;
@@ -132,9 +135,10 @@ macro_rules! pydtype {
        }
    };
}

pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
pydtype!(i64, |v| v);
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
@@ -317,6 +321,13 @@ impl PyTensor {
         PyTuple::new(py, self.0.dims()).to_object(py)
     }

+    #[getter]
+    /// Gets the tensor's element count.
+    /// &RETURNS&: int
+    fn nelement(&self) -> usize {
+        self.0.elem_count()
+    }
+
     #[getter]
     /// Gets the tensor's strides.
     /// &RETURNS&: Tuple[int]
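This getter backs the `Tensor.nelement` property in the stubs above (renamed from `nelements`, per the commit message). A quick illustrative check, assuming the built bindings:

```python
import candle

t = candle.rand((3, 4))
assert t.nelement == 12   # element count = product of the dimensions
assert t.shape == (3, 4)
```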
@@ -353,6 +364,12 @@ impl PyTensor {
         self.__repr__()
     }

+    /// Performs the `abs` operation on the tensor.
+    /// &RETURNS&: Tensor
+    fn abs(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
+    }
+
     /// Performs the `sin` operation on the tensor.
     /// &RETURNS&: Tensor
     fn sin(&self) -> PyResult<Self> {
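A minimal illustration of the new `abs` binding from Python (a sketch, assuming the built bindings; `values()` returns plain Python lists as elsewhere in the tests):

```python
from candle import Tensor

t = Tensor([-1.0, 2.0, -3.0])
assert t.abs().values() == [1.0, 2.0, 3.0]
```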
@@ -670,6 +687,58 @@ impl PyTensor {
         };
         Ok(Self(tensor))
     }
+    /// Rich-compare two tensors.
+    /// &RETURNS&: Tensor
+    fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
+        let compare = |lhs: &Tensor, rhs: &Tensor| {
+            let t = match op {
+                CompareOp::Eq => lhs.eq(rhs),
+                CompareOp::Ne => lhs.ne(rhs),
+                CompareOp::Lt => lhs.lt(rhs),
+                CompareOp::Le => lhs.le(rhs),
+                CompareOp::Gt => lhs.gt(rhs),
+                CompareOp::Ge => lhs.ge(rhs),
+            };
+            Ok(PyTensor(t.map_err(wrap_err)?))
+        };
+        if let Ok(rhs) = rhs.extract::<PyTensor>() {
+            if self.0.shape() == rhs.0.shape() {
+                compare(&self.0, &rhs.0)
+            } else {
+                // We broadcast manually here because `candle.cmp` does not support automatic broadcasting
+                let broadcast_shape = self
+                    .0
+                    .shape()
+                    .broadcast_shape_binary_op(rhs.0.shape(), "cmp")
+                    .map_err(wrap_err)?;
+                let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
+                let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
+
+                compare(&broadcasted_lhs, &broadcasted_rhs)
+            }
+        } else if let Ok(rhs) = rhs.extract::<f64>() {
+            let scalar_tensor = Tensor::new(rhs, self.0.device())
+                .map_err(wrap_err)?
+                .to_dtype(self.0.dtype())
+                .map_err(wrap_err)?
+                .broadcast_as(self.0.shape())
+                .map_err(wrap_err)?;
+
+            compare(&self.0, &scalar_tensor)
+        } else {
+            return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
+        }
+    }
+
+    fn __hash__(&self) -> u64 {
+        // we have overridden __richcmp__ => pyo3 wants us to also override __hash__
+        // we simply hash the address of the tensor
+        let mut hasher = DefaultHasher::new();
+        let pointer = &self.0 as *const Tensor;
+        let address = pointer as usize;
+        address.hash(&mut hasher);
+        hasher.finish()
+    }

     #[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
     /// Reshapes the tensor to the given shape.
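From Python, the scalar path and the hash override above look like this (a sketch mirroring the tests added later in this commit):

```python
from candle import Tensor

t = Tensor(42.0)
assert bool((t != 43.0).values())   # rhs is extracted as an f64 scalar and broadcast

# __hash__ hashes the tensor's address, so two distinct tensors hash differently
# even when they hold the same value.
t2 = Tensor(42.0)
assert hash(t) != hash(t2)
assert hash(t) == hash(t)
```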
@@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyDType>()?;
     m.add("u8", PyDType(DType::U8))?;
     m.add("u32", PyDType(DType::U32))?;
-    m.add("i16", PyDType(DType::I64))?;
+    m.add("i64", PyDType(DType::I64))?;
     m.add("bf16", PyDType(DType::BF16))?;
     m.add("f16", PyDType(DType::F16))?;
     m.add("f32", PyDType(DType::F32))?;
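This fixes the module attribute name: `DType::I64` was previously registered under `i16` and is now exposed as `candle.i64`, matching the "`candle.i16` => `candle.i64`" note in the commit message. An illustrative check, assuming the built bindings:

```python
import candle

t = candle.Tensor([1, 2, 3]).to(candle.i64)
assert str(t.dtype) == str(candle.i64)
```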
@@ -0,0 +1,33 @@ (new file: tests for the `candle.testing` helpers)
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
    a = Tensor([1, 2, 3]).to(dtype)
    b = Tensor([1, 2, 3]).to(dtype)
    assert_equal(a, b)

    with pytest.raises(AssertionError):
        assert_equal(a, b + 1)


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
    a = Tensor([1, 2, 3]).to(dtype)
    b = Tensor([1, 2, 3]).to(dtype)
    assert_almost_equal(a, b)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1)

    assert_almost_equal(a, b + 1, atol=20)
    assert_almost_equal(a, b + 1, rtol=20)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1, atol=0.9)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1, rtol=0.1)
@@ -1,6 +1,7 @@
 import candle
 from candle import Tensor
 from candle.utils import cuda_is_available
+from candle.testing import assert_equal
 import pytest
@@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d():
     assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]


+def assert_bool(t: Tensor, expected: bool):
+    assert t.shape == ()
+    assert str(t.dtype) == str(candle.u8)
+    assert bool(t.values()) == expected
+
+
+def test_tensor_supports_equality_opperations_with_scalars():
+    t = Tensor(42.0)
+
+    assert_bool(t == 42.0, True)
+    assert_bool(t == 43.0, False)
+
+    assert_bool(t != 42.0, False)
+    assert_bool(t != 43.0, True)
+
+    assert_bool(t > 41.0, True)
+    assert_bool(t > 42.0, False)
+
+    assert_bool(t >= 41.0, True)
+    assert_bool(t >= 42.0, True)
+
+    assert_bool(t < 43.0, True)
+    assert_bool(t < 42.0, False)
+
+    assert_bool(t <= 43.0, True)
+    assert_bool(t <= 42.0, True)
+
+
+def test_tensor_supports_equality_opperations_with_tensors():
+    t = Tensor(42.0)
+    same = Tensor(42.0)
+    other = Tensor(43.0)
+
+    assert_bool(t == same, True)
+    assert_bool(t == other, False)
+
+    assert_bool(t != same, False)
+    assert_bool(t != other, True)
+
+    assert_bool(t > same, False)
+    assert_bool(t > other, False)
+
+    assert_bool(t >= same, True)
+    assert_bool(t >= other, False)
+
+    assert_bool(t < same, False)
+    assert_bool(t < other, True)
+
+    assert_bool(t <= same, True)
+    assert_bool(t <= other, True)
+
+
+def test_tensor_equality_opperations_can_broadcast():
+    # Create a decoder attention mask as a test case
+    # e.g.
+    # [[1,0,0]
+    #  [1,1,0]
+    #  [1,1,1]]
+    mask_cond = candle.Tensor([0, 1, 2])
+    mask = mask_cond < (mask_cond + 1).reshape((3, 1))
+    assert mask.shape == (3, 3)
+    assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
+
+
+def test_tensor_can_be_hashed():
+    t = Tensor(42.0)
+    other = Tensor(42.0)
+    # Hash should represent a unique tensor
+    assert hash(t) != hash(other)
+    assert hash(t) == hash(t)
+
+
 def test_tensor_can_be_expanded_with_none():
     t = candle.rand((12, 12))