Fix problems in Tensor.from_numpy()

Use py::buffer_info instead of py::array in TensorDataNumpy,
so the buffer is not re-requested every time data() is called.
This also works around the py::array.request() hang-up problem.
He Wei 2020-09-01 10:37:44 +08:00
parent d5e02cf474
commit cbfd4c5f40
2 changed files with 22 additions and 39 deletions
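
For context, the gist of the change is sketched below. This is a simplified, hypothetical illustration (class and member names are made up, assuming pybind11's numpy support), not the MindSpore code itself: the old design stored a py::array and went through the buffer protocol on every data() call, while the new design requests the buffer once and keeps the resulting py::buffer_info.

// Minimal sketch, assuming pybind11; names are illustrative only.
#include <pybind11/numpy.h>
#include <utility>

namespace py = pybind11;

// Old approach: keep the array and re-request its buffer on every access.
class ArrayBackedData {
 public:
  explicit ArrayBackedData(const py::array &input) : data_(input) {}
  // Each call goes through the Python buffer protocol again.
  void *data() { return data_.request().ptr; }

 private:
  py::array data_;
};

// New approach: request the buffer once, then use plain pointer access.
class BufferBackedData {
 public:
  explicit BufferBackedData(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
  void *data() { return buffer_.ptr; }
  ssize_t nbytes() const { return buffer_.itemsize * buffer_.size; }

 private:
  py::buffer_info buffer_;
};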

@@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) {
 // TensorDataNumpy implements TensorData using numpy array.
 class TensorDataNumpy : public TensorData {
  public:
-  explicit TensorDataNumpy(const py::array &input) : data_(input) {
-    if (!IsCContiguous(data_)) {
-      // Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
-      auto np = py::module::import("numpy");
-      auto convert = np.attr("ascontiguousarray");
-      data_ = convert(data_);
-    }
-  }
+  explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
   /// Total number of elements.
-  ssize_t size() const override { return data_.size(); }
+  ssize_t size() const override { return buffer_.size; }
   /// Byte size of a single element.
-  ssize_t itemsize() const override { return data_.itemsize(); }
+  ssize_t itemsize() const override { return buffer_.itemsize; }
   /// Total number of bytes.
-  ssize_t nbytes() const override { return data_.nbytes(); }
+  ssize_t nbytes() const override { return buffer_.itemsize * buffer_.size; }
   /// Number of dimensions.
-  ssize_t ndim() const override { return data_.ndim(); }
+  ssize_t ndim() const override { return buffer_.ndim; }
   /// Data pointer.
-  void *data() override { return data_.request().ptr; }
+  void *data() override { return buffer_.ptr; }
-  const void *const_data() const override { return data_.request().ptr; }
-  /// Is data equals.
-  bool equals(const TensorData &other) const override {
-    auto ptr = dynamic_cast<const TensorDataNumpy *>(&other);
-    if (ptr == nullptr) {
-      // Not same type, compare data byte by byte.
-      return TensorData::equals(other);
-    }
-    return NumpyEquals(*ptr);
-  }
-  bool NumpyEquals(const TensorDataNumpy &other) const {
-    auto all_data_equal = [&other, this]() -> bool {
-      auto np = py::module::import("numpy");
-      auto equal = np.attr("equal")(data_, other.data_);
-      auto all_equal = np.attr("all")(equal);
-      return all_equal.cast<bool>();
-    };
-    return this == &other || data_.is(other.data_) || all_data_equal();
-  }
+  const void *const_data() const override { return buffer_.ptr; }
   /// To string.
   std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
@@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData {
       kwargs["separator"] = ", ";
       auto np = py::module::import("numpy");
       auto array2string = np.attr("array2string");
-      return py::str(array2string(data_, **kwargs));
+      return py::str(array2string(py_array(), **kwargs));
     }
     // without comma.
-    return py::str(data_);
+    return py::str(py_array());
   }
   /// py::array object.
-  py::array py_array() const { return data_; }
+  py::array py_array() const {
+    // Use dummy owner to avoid copy data.
+    py::str dummyOwner;
+    return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner);
+  }
  private:
-  mutable py::array data_;
+  py::buffer_info buffer_;
 };
 TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
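
A note on the new py_array() above: constructing a py::array from a raw pointer normally makes pybind11 allocate and copy the data, but when an owner ("base") handle is passed, the array becomes a view of the existing memory. The dummy py::str owner is only there to trigger that view behaviour; the memory itself stays owned by the stored buffer_info. A minimal stand-alone sketch of the same trick (helper name is hypothetical):

#include <pybind11/numpy.h>
#include <vector>

namespace py = pybind11;

// Wrap existing memory as a numpy array without copying it.
// Caller must keep the memory alive for as long as the returned array is used.
py::array view_without_copy(void *ptr, const std::vector<ssize_t> &shape,
                            const std::vector<ssize_t> &strides, const py::dtype &dt) {
  py::str dummy_owner;  // any owner object prevents pybind11 from copying the data
  return py::array(dt, shape, strides, ptr, dummy_owner);
}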
@@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
 /// Creates a Tensor from a numpy array without copy
 TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
+  // Check format.
+  if (!IsCContiguous(input)) {
+    MS_LOG(EXCEPTION) << "Array should be C contiguous.";
+  }
   // Get input buffer info.
   py::buffer_info buf = input.request();
   // Get tensor dtype and check it.
@@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
   // Get tensor shape.
   ShapeVector shape(buf.shape.begin(), buf.shape.end());
   // Make a tensor with shared data with numpy array.
-  auto tensor_data = std::make_shared<TensorDataNumpy>(input);
+  auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf));
   return std::make_shared<Tensor>(dtype, shape, tensor_data);
 }
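
The no-copy path now rejects non-contiguous input instead of silently converting it, since a converted copy would defeat the purpose of sharing memory. The IsCContiguous() helper referenced in the first hunk header is not part of this diff; a plausible implementation (a sketch, not necessarily the exact MindSpore code) checks numpy's C_CONTIGUOUS flag through pybind11:

#include <pybind11/numpy.h>

namespace py = pybind11;

// True if the array's memory layout is C contiguous (row-major, no gaps).
static bool IsCContiguous(const py::array &input) {
  return (input.flags() & py::array::c_style) != 0;
}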

@@ -42,6 +42,8 @@ namespace tensor {
 // Tensor data interface.
 class TensorData {
  public:
+  /// virtual destructor is required for base classes.
+  virtual ~TensorData() = default;
   /// Total number of elements.
   virtual ssize_t size() const = 0;
   /// Byte size of a single element.
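
The second file adds a defaulted virtual destructor to the TensorData interface. Without it, destroying a derived object such as TensorDataNumpy through a TensorData pointer is undefined behaviour, and the derived destructor (which releases the held numpy buffer) may never run. A generic illustration (not MindSpore code):

#include <memory>

struct Base {
  virtual ~Base() = default;  // required when deleting derived objects via Base*
};

struct Derived : Base {
  ~Derived() override { /* release resources, e.g. a held buffer */ }
};

int main() {
  std::unique_ptr<Base> p = std::make_unique<Derived>();
  // ~Derived() runs here only because ~Base() is virtual.
  return 0;
}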