convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
This commit is contained in:
parent
0acd16751d
commit
6a446d9d73
|
@ -396,6 +396,11 @@ class Tensor:
|
|||
Convert the tensor to a new dtype.
|
||||
"""
|
||||
pass
|
||||
def to_torch(self) -> torch.Tensor:
|
||||
"""
|
||||
Converts candle's tensor to pytorch's tensor
|
||||
"""
|
||||
pass
|
||||
def transpose(self, dim1: int, dim2: int) -> Tensor:
|
||||
"""
|
||||
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
|
||||
|
|
|
@ -211,6 +211,16 @@ enum Indexer {
|
|||
IndexSelect(Tensor),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TorchTensor(PyObject);
|
||||
|
||||
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
|
||||
fn extract(ob: &'source PyAny) -> PyResult<Self> {
|
||||
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
|
||||
Ok(TorchTensor(numpy_value))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
|
@ -246,6 +256,8 @@ impl PyTensor {
|
|||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
|
||||
Tensor::new(vs, &Cpu).map_err(wrap_err)?
|
||||
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
|
||||
return PyTensor::new(py, numpy);
|
||||
} else {
|
||||
let ty = data.as_ref(py).get_type();
|
||||
Err(PyTypeError::new_err(format!(
|
||||
|
@ -299,6 +311,18 @@ impl PyTensor {
|
|||
M(py).map(self)
|
||||
}
|
||||
|
||||
/// Converts candle's tensor to pytorch's tensor
|
||||
/// &RETURNS&: torch.Tensor
|
||||
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||
let candle_values = self.values(py)?;
|
||||
let torch_tensor: PyObject = py
|
||||
.import("torch")?
|
||||
.getattr("tensor")?
|
||||
.call1((candle_values,))?
|
||||
.extract()?;
|
||||
Ok(torch_tensor)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Gets the tensor's shape.
|
||||
/// &RETURNS&: Tuple[int]
|
||||
|
|
|
@ -0,0 +1,14 @@
|
|||
import candle
|
||||
import torch
|
||||
|
||||
# convert from candle tensor to torch tensor
|
||||
t = candle.randn((3, 512, 512))
|
||||
torch_tensor = t.to_torch()
|
||||
print(torch_tensor)
|
||||
print(type(torch_tensor))
|
||||
|
||||
# convert from torch tensor to candle tensor
|
||||
t = torch.randn((3, 512, 512))
|
||||
candle_tensor = candle.Tensor(t)
|
||||
print(candle_tensor)
|
||||
print(type(candle_tensor))
|
Loading…
Reference in New Issue