!5616 Fix problems in Tensor.from_numpy()

Merge pull request !5616 from hewei/fix_tensor_from_numpy2
This commit is contained in:
mindspore-ci-bot 2020-09-01 19:59:01 +08:00 committed by Gitee
commit 59b1358641
2 changed files with 22 additions and 39 deletions

View File

@@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) {
// TensorDataNumpy implements TensorData using numpy array. // TensorDataNumpy implements TensorData using numpy array.
class TensorDataNumpy : public TensorData { class TensorDataNumpy : public TensorData {
public: public:
explicit TensorDataNumpy(const py::array &input) : data_(input) { explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
if (!IsCContiguous(data_)) {
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
auto np = py::module::import("numpy");
auto convert = np.attr("ascontiguousarray");
data_ = convert(data_);
}
}
/// Total number of elements. /// Total number of elements.
ssize_t size() const override { return data_.size(); } ssize_t size() const override { return buffer_.size; }
/// Byte size of a single element. /// Byte size of a single element.
ssize_t itemsize() const override { return data_.itemsize(); } ssize_t itemsize() const override { return buffer_.itemsize; }
/// Total number of bytes. /// Total number of bytes.
ssize_t nbytes() const override { return data_.nbytes(); } ssize_t nbytes() const override { return buffer_.itemsize * buffer_.size; }
/// Number of dimensions. /// Number of dimensions.
ssize_t ndim() const override { return data_.ndim(); } ssize_t ndim() const override { return buffer_.ndim; }
/// Data pointer. /// Data pointer.
void *data() override { return data_.request().ptr; } void *data() override { return buffer_.ptr; }
const void *const_data() const override { return data_.request().ptr; } const void *const_data() const override { return buffer_.ptr; }
/// Is data equals.
bool equals(const TensorData &other) const override {
auto ptr = dynamic_cast<const TensorDataNumpy *>(&other);
if (ptr == nullptr) {
// Not same type, compare data byte by byte.
return TensorData::equals(other);
}
return NumpyEquals(*ptr);
}
bool NumpyEquals(const TensorDataNumpy &other) const {
auto all_data_equal = [&other, this]() -> bool {
auto np = py::module::import("numpy");
auto equal = np.attr("equal")(data_, other.data_);
auto all_equal = np.attr("all")(equal);
return all_equal.cast<bool>();
};
return this == &other || data_.is(other.data_) || all_data_equal();
}
/// To string. /// To string.
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override { std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
@@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData {
kwargs["separator"] = ", "; kwargs["separator"] = ", ";
auto np = py::module::import("numpy"); auto np = py::module::import("numpy");
auto array2string = np.attr("array2string"); auto array2string = np.attr("array2string");
return py::str(array2string(data_, **kwargs)); return py::str(array2string(py_array(), **kwargs));
} }
// without comma. // without comma.
return py::str(data_); return py::str(py_array());
} }
/// py::array object. /// py::array object.
py::array py_array() const { return data_; } py::array py_array() const {
// Use dummy owner to avoid copy data.
py::str dummyOwner;
return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner);
}
private: private:
mutable py::array data_; py::buffer_info buffer_;
}; };
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) { TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
@@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
/// Creates a Tensor from a numpy array without copy /// Creates a Tensor from a numpy array without copy
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) { TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
// Check format.
if (!IsCContiguous(input)) {
MS_LOG(EXCEPTION) << "Array should be C contiguous.";
}
// Get input buffer info. // Get input buffer info.
py::buffer_info buf = input.request(); py::buffer_info buf = input.request();
// Get tensor dtype and check it. // Get tensor dtype and check it.
@@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
// Get tensor shape. // Get tensor shape.
ShapeVector shape(buf.shape.begin(), buf.shape.end()); ShapeVector shape(buf.shape.begin(), buf.shape.end());
// Make a tensor with shared data with numpy array. // Make a tensor with shared data with numpy array.
auto tensor_data = std::make_shared<TensorDataNumpy>(input); auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf));
return std::make_shared<Tensor>(dtype, shape, tensor_data); return std::make_shared<Tensor>(dtype, shape, tensor_data);
} }

View File

@@ -42,6 +42,8 @@ namespace tensor {
// Tensor data interface. // Tensor data interface.
class TensorData { class TensorData {
public: public:
/// virtual destructor is required for base classes.
virtual ~TensorData() = default;
/// Total number of elements. /// Total number of elements.
virtual ssize_t size() const = 0; virtual ssize_t size() const = 0;
/// Byte size of a single element. /// Byte size of a single element.