convert pytorch's tensor in Python API (#1172)

* convert pytorch's tensor

* separate tests for convert pytorch tensor
This commit is contained in:
andrew 2023-10-26 01:39:14 +07:00 committed by GitHub
parent 0acd16751d
commit 6a446d9d73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 0 deletions

View File

@ -396,6 +396,11 @@ class Tensor:
Convert the tensor to a new dtype.
"""
pass
def to_torch(self) -> torch.Tensor:
"""
Converts candle's tensor to pytorch's tensor
"""
pass
def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.

View File

@ -211,6 +211,16 @@ enum Indexer {
IndexSelect(Tensor),
}
#[derive(Clone, Debug)]
struct TorchTensor(PyObject);
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
}
#[pymethods]
impl PyTensor {
#[new]
@ -246,6 +256,8 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
return PyTensor::new(py, numpy);
} else {
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
@ -299,6 +311,18 @@ impl PyTensor {
M(py).map(self)
}
/// Converts candle's tensor to pytorch's tensor
/// &RETURNS&: torch.Tensor
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
let candle_values = self.values(py)?;
let torch_tensor: PyObject = py
.import("torch")?
.getattr("tensor")?
.call1((candle_values,))?
.extract()?;
Ok(torch_tensor)
}
#[getter]
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]

View File

@ -0,0 +1,14 @@
import candle
import torch
# convert from candle tensor to torch tensor
t = candle.randn((3, 512, 512))
torch_tensor = t.to_torch()
print(torch_tensor)
print(type(torch_tensor))
# convert from torch tensor to candle tensor
t = torch.randn((3, 512, 512))
candle_tensor = candle.Tensor(t)
print(candle_tensor)
print(type(candle_tensor))