forked from mindspore-Ecosystem/mindspore
!5616 Fix problems in Tensor.from_numpy()
Merge pull request !5616 from hewei/fix_tensor_from_numpy2
This commit is contained in:
commit
59b1358641
|
@ -120,51 +120,24 @@ static bool IsCContiguous(const py::array &input) {
|
||||||
// TensorDataNumpy implements TensorData using numpy array.
|
// TensorDataNumpy implements TensorData using numpy array.
|
||||||
class TensorDataNumpy : public TensorData {
|
class TensorDataNumpy : public TensorData {
|
||||||
public:
|
public:
|
||||||
explicit TensorDataNumpy(const py::array &input) : data_(input) {
|
explicit TensorDataNumpy(py::buffer_info &&buffer) : buffer_(std::move(buffer)) {}
|
||||||
if (!IsCContiguous(data_)) {
|
|
||||||
// Call numpy.ascontiguousarray() to convert data to C contiguous if it is not.
|
|
||||||
auto np = py::module::import("numpy");
|
|
||||||
auto convert = np.attr("ascontiguousarray");
|
|
||||||
data_ = convert(data_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Total number of elements.
|
/// Total number of elements.
|
||||||
ssize_t size() const override { return data_.size(); }
|
ssize_t size() const override { return buffer_.size; }
|
||||||
|
|
||||||
/// Byte size of a single element.
|
/// Byte size of a single element.
|
||||||
ssize_t itemsize() const override { return data_.itemsize(); }
|
ssize_t itemsize() const override { return buffer_.itemsize; }
|
||||||
|
|
||||||
/// Total number of bytes.
|
/// Total number of bytes.
|
||||||
ssize_t nbytes() const override { return data_.nbytes(); }
|
ssize_t nbytes() const override { return buffer_.itemsize * buffer_.size; }
|
||||||
|
|
||||||
/// Number of dimensions.
|
/// Number of dimensions.
|
||||||
ssize_t ndim() const override { return data_.ndim(); }
|
ssize_t ndim() const override { return buffer_.ndim; }
|
||||||
|
|
||||||
/// Data pointer.
|
/// Data pointer.
|
||||||
void *data() override { return data_.request().ptr; }
|
void *data() override { return buffer_.ptr; }
|
||||||
|
|
||||||
const void *const_data() const override { return data_.request().ptr; }
|
const void *const_data() const override { return buffer_.ptr; }
|
||||||
|
|
||||||
/// Is data equals.
|
|
||||||
bool equals(const TensorData &other) const override {
|
|
||||||
auto ptr = dynamic_cast<const TensorDataNumpy *>(&other);
|
|
||||||
if (ptr == nullptr) {
|
|
||||||
// Not same type, compare data byte by byte.
|
|
||||||
return TensorData::equals(other);
|
|
||||||
}
|
|
||||||
return NumpyEquals(*ptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool NumpyEquals(const TensorDataNumpy &other) const {
|
|
||||||
auto all_data_equal = [&other, this]() -> bool {
|
|
||||||
auto np = py::module::import("numpy");
|
|
||||||
auto equal = np.attr("equal")(data_, other.data_);
|
|
||||||
auto all_equal = np.attr("all")(equal);
|
|
||||||
return all_equal.cast<bool>();
|
|
||||||
};
|
|
||||||
return this == &other || data_.is(other.data_) || all_data_equal();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// To string.
|
/// To string.
|
||||||
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
|
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
|
||||||
|
@ -174,17 +147,21 @@ class TensorDataNumpy : public TensorData {
|
||||||
kwargs["separator"] = ", ";
|
kwargs["separator"] = ", ";
|
||||||
auto np = py::module::import("numpy");
|
auto np = py::module::import("numpy");
|
||||||
auto array2string = np.attr("array2string");
|
auto array2string = np.attr("array2string");
|
||||||
return py::str(array2string(data_, **kwargs));
|
return py::str(array2string(py_array(), **kwargs));
|
||||||
}
|
}
|
||||||
// without comma.
|
// without comma.
|
||||||
return py::str(data_);
|
return py::str(py_array());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// py::array object.
|
/// py::array object.
|
||||||
py::array py_array() const { return data_; }
|
py::array py_array() const {
|
||||||
|
// Use dummy owner to avoid copy data.
|
||||||
|
py::str dummyOwner;
|
||||||
|
return py::array(py::dtype(buffer_), buffer_.shape, buffer_.strides, buffer_.ptr, dummyOwner);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mutable py::array data_;
|
py::buffer_info buffer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
|
TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr) {
|
||||||
|
@ -226,6 +203,10 @@ TensorPtr TensorPy::MakeTensor(const py::array &input, const TypePtr &type_ptr)
|
||||||
|
|
||||||
/// Creates a Tensor from a numpy array without copy
|
/// Creates a Tensor from a numpy array without copy
|
||||||
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
|
TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
|
||||||
|
// Check format.
|
||||||
|
if (!IsCContiguous(input)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Array should be C contiguous.";
|
||||||
|
}
|
||||||
// Get input buffer info.
|
// Get input buffer info.
|
||||||
py::buffer_info buf = input.request();
|
py::buffer_info buf = input.request();
|
||||||
// Get tensor dtype and check it.
|
// Get tensor dtype and check it.
|
||||||
|
@ -236,7 +217,7 @@ TensorPtr TensorPy::MakeTensorNoCopy(const py::array &input) {
|
||||||
// Get tensor shape.
|
// Get tensor shape.
|
||||||
ShapeVector shape(buf.shape.begin(), buf.shape.end());
|
ShapeVector shape(buf.shape.begin(), buf.shape.end());
|
||||||
// Make a tensor with shared data with numpy array.
|
// Make a tensor with shared data with numpy array.
|
||||||
auto tensor_data = std::make_shared<TensorDataNumpy>(input);
|
auto tensor_data = std::make_shared<TensorDataNumpy>(std::move(buf));
|
||||||
return std::make_shared<Tensor>(dtype, shape, tensor_data);
|
return std::make_shared<Tensor>(dtype, shape, tensor_data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,8 @@ namespace tensor {
|
||||||
// Tensor data interface.
|
// Tensor data interface.
|
||||||
class TensorData {
|
class TensorData {
|
||||||
public:
|
public:
|
||||||
|
/// virtual destructor is required for base classes.
|
||||||
|
virtual ~TensorData() = default;
|
||||||
/// Total number of elements.
|
/// Total number of elements.
|
||||||
virtual ssize_t size() const = 0;
|
virtual ssize_t size() const = 0;
|
||||||
/// Byte size of a single element.
|
/// Byte size of a single element.
|
||||||
|
|
Loading…
Reference in New Issue