forked from mindspore-Ecosystem/mindspore
Fix problems in Tensor.from_numpy()
Use py::buffer_info instead of py::array in TensorDataNumpy, to avoid requesting the buffer every time data() is called. This is also a workaround for the problem of py::array.request() hanging.
This commit is contained in:
parent
d5e02cf474
commit
cbfd4c5f40
|
@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) {
|
|||
// TensorDataNumpy implements TensorData using numpy array.
|
||||
class TensorDataNumpy : public TensorData {
|
||||
public:
|
||||
explicit TensorDataNumpy(const py::array &input) : data_(input) {
|
||||
if (!IsCContiguous(data_)) {
|
||||
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
|
||||
auto np = py::module::import("numpy");
|
||||
auto convert = np.attr("ascontiguousarray");
|
||||
data_ = convert(data_);
|
||||
}
|
||||
}
|
||||
explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
|
||||
|
||||
/// Total number of elements.
|
||||
ssize_t size() const override { return data_.size(); }
|
||||
ssize_t size() const override { return buffer_.size; }
|
||||
|
||||
/// Byte size of a single element.
|
||||
ssize_t itemsize() const override { return data_.itemsize(); }
|
||||
ssize_t itemsize() const override { return buffer_.itemsize; }
|
||||
|
||||
/// Total number of bytes.
|
||||
ssize_t nbytes() const override { return data_.nbytes(); }
|
||||
ssize_t nbytes() const override { return buffer_.itemsize * buffer_.size; }
|
||||
|
||||
/// Number of dimensions.
|
||||
ssize_t ndim() const override { return data_.ndim(); }
|
||||
ssize_t ndim() const override { return buffer_.ndim; }
|
||||
|
||||
/// Data pointer.
|
||||
void *data() override { return data_.request().ptr; }
|
||||
void *data() override { return buffer_.ptr; }
|
||||
|
||||
const void *const_data() const override { return data_.request().ptr; }
|
||||
|
||||
/// Whether the data is equal.
|
||||
bool equals(const TensorData &other) const override {
|
||||
auto ptr = dynamic_cast<const TensorDataNumpy *>(&other);
|
||||
if (ptr == nullptr) {
|
||||
// Not same type, compare data byte by byte.
|
||||
return TensorData::equals(other);
|
||||
}
|
||||
return NumpyEquals(*ptr);
|
||||
}
|
||||
|
||||
bool NumpyEquals(const TensorDataNumpy &other) const {
|
||||
auto all_data_equal = [&other, this]() -> bool {
|
||||
auto np = py::module::import("numpy");
|
||||
auto equal = np.attr("equal")(data_, other.data_);
|
||||
auto all_equal = np.attr("all")(equal);
|
||||
return all_equal.cast<bool>();
|
||||
};
|
||||
return this == &other || data_.is(other.data_) || all_data_equal();
|
||||
}
|
||||
const void *const_data() const override { return buffer_.ptr; }
|
||||
|
||||
/// To string.
|
||||
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
|
||||
|
@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData {
|
|||
kwargs["separator"] = ", ";
|
||||
auto np = py::module::import("numpy");
|
||||
auto array2string = np.attr("array2string");
|
||||
return py::str(array2string(data_, **kwargs));
|
||||
return py::str(array2string(py_array(), **kwargs));
|
||||
}
|
||||
// without comma.
|
||||
return py::str(data_);
|
||||
return py::str(py_array());
|
||||
}
|
||||
|
||||
/// py::array object.
|
||||
py::array py_array() const { return data_; }
|
||||
py::array py_array() const {
|
||||
// Use a dummy owner to avoid copying the data.
|
||||
py::str dummyOwner;
|
||||
return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner);
|
||||
}
|
||||
|
||||
private:
|
||||
mutable py::array data_;
|
||||
py::buffer_info buffer_;
|
||||
};
|
||||
|
||||
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
|
||||
|
@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
|
|||
|
||||
/// Creates a Tensor from a numpy array without copy
|
||||
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
|
||||
// Check format.
|
||||
if (!IsCContiguous(input)) {
|
||||
MS_LOG(EXCEPTION) << "Array should be C contiguous.";
|
||||
}
|
||||
// Get input buffer info.
|
||||
py::buffer_info buf = input.request();
|
||||
// Get tensor dtype and check it.
|
||||
|
@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
|
|||
// Get tensor shape.
|
||||
ShapeVector shape(buf.shape.begin(), buf.shape.end());
|
||||
// Make a tensor with shared data with numpy array.
|
||||
auto tensor_data = std::make_shared<TensorDataNumpy>(input);
|
||||
auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf));
|
||||
return std::make_shared<Tensor>(dtype, shape, tensor_data);
|
||||
}
|
||||
|
||||
|
|
|
@ -42,6 +42,8 @@ namespace tensor {
|
|||
// Tensor data interface.
|
||||
class TensorData {
|
||||
public:
|
||||
/// virtual destructor is required for base classes.
|
||||
virtual ~TensorData() = default;
|
||||
/// Total number of elements.
|
||||
virtual ssize_t size() const = 0;
|
||||
/// Byte size of a single element.
|
||||
|
|
Loading…
Reference in New Issue