diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc index 65818589f43..56c549c9cd9 100644 --- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc @@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) { // TensorDataNumpy implements TensorData using numpy array. class TensorDataNumpy : public TensorData { public: - explicit TensorDataNumpy(const py::array &input) : data_(input) { - if (!IsCContiguous(data_)) { - // Call numpy.ascontiguousarray() to convert data to C contiguous if it is not. - auto np = py::module::import("numpy"); - auto convert = np.attr("ascontiguousarray"); - data_ = convert(data_); - } - } + explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {} /// Total number of elements. - ssize_t size() const override { return data_.size(); } + ssize_t size() const override { return buffer_.size; } /// Byte size of a single element. - ssize_t itemsize() const override { return data_.itemsize(); } + ssize_t itemsize() const override { return buffer_.itemsize; } /// Total number of bytes. - ssize_t nbytes() const override { return data_.nbytes(); } + ssize_t nbytes() const override { return buffer_.itemsize * buffer_.size; } /// Number of dimensions. - ssize_t ndim() const override { return data_.ndim(); } + ssize_t ndim() const override { return buffer_.ndim; } /// Data pointer. - void *data() override { return data_.request().ptr; } + void *data() override { return buffer_.ptr; } - const void *const_data() const override { return data_.request().ptr; } - - /// Is data equals. - bool equals(const TensorData &other) const override { - auto ptr = dynamic_cast<const TensorDataNumpy *>(&other); - if (ptr == nullptr) { - // Not same type, compare data byte by byte. - return TensorData::equals(other); - } - return NumpyEquals(*ptr); - } - - bool NumpyEquals(const TensorDataNumpy &other) const { - auto all_data_equal = [&other, this]() -> bool { - auto np = py::module::import("numpy"); - auto equal = np.attr("equal")(data_, other.data_); - auto all_equal = np.attr("all")(equal); - return all_equal.cast<bool>(); - }; - return this == &other || data_.is(other.data_) || all_data_equal(); - } + const void *const_data() const override { return buffer_.ptr; } /// To string. std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override { @@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData { kwargs["separator"] = ", "; auto np = py::module::import("numpy"); auto array2string = np.attr("array2string"); - return py::str(array2string(data_, **kwargs)); + return py::str(array2string(py_array(), **kwargs)); } // without comma. - return py::str(data_); + return py::str(py_array()); } /// py::array object. - py::array py_array() const { return data_; } + py::array py_array() const { + // Use dummy owner to avoid copy data. + py::str dummyOwner; + return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner); + } private: - mutable py::array data_; + py::buffer_info buffer_; }; TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { @@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) /// Creates a Tensor from a numpy array without copy TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) { + // Check format. + if (!IsCContiguous(input)) { + MS_LOG(EXCEPTION) << "Array should be C contiguous."; + } // Get input buffer info. py::buffer_info buf = input.request(); // Get tensor dtype and check it. @@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) { // Get tensor shape. ShapeVector shape(buf.shape.begin(), buf.shape.end()); // Make a tensor with shared data with numpy array. - auto tensor_data = std::make_shared<TensorDataNumpy>(input); + auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf)); return std::make_shared<Tensor>(dtype, shape, tensor_data); } diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 1228131c197..eb84cc1fce1 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -42,6 +42,8 @@ namespace tensor { // Tensor data interface. class TensorData { public: + /// virtual destructor is required for base classes. + virtual ~TensorData() = default; /// Total number of elements. virtual ssize_t size() const = 0; /// Byte size of a single element.