forked from mindspore-Ecosystem/mindspore
!5330 Add comma seperator for python tensor __repr__().
Merge pull request !5330 from ZhangQinghua/master
This commit is contained in:
commit
41009707e6
|
@ -210,7 +210,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
|
|||
mindspore::tensor::Tensor print_tensor(type_id, tensor_shape);
|
||||
auto memory_size = totaldims * type_size_map[item.tensorType_];
|
||||
if (PrintTensorToString(str_data_ptr->data(), &print_tensor, memory_size)) {
|
||||
buf << print_tensor.ToStringRepr() << std::endl;
|
||||
buf << print_tensor.ToStringNoLimit() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -213,7 +213,7 @@ class TensorDataImpl : public TensorData {
|
|||
std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
|
||||
}
|
||||
|
||||
std::string ToString(const TypeId type, const ShapeVector &shape) const override {
|
||||
std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const override {
|
||||
constexpr auto valid =
|
||||
std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value ||
|
||||
std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value ||
|
||||
|
@ -229,16 +229,16 @@ class TensorDataImpl : public TensorData {
|
|||
|
||||
std::ostringstream ss;
|
||||
if (data_size_ == 1 && ndim_ == 0) { // Scalar
|
||||
OutputDataString(ss, 0, 0, 1);
|
||||
OutputDataString(ss, 0, 0, 1, false);
|
||||
return ss.str();
|
||||
}
|
||||
ssize_t cursor = 0;
|
||||
SummaryStringRecursive(ss, shape, &cursor, 0);
|
||||
SummaryStringRecursive(ss, shape, &cursor, 0, use_comma);
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
private:
|
||||
void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end) const {
|
||||
void OutputDataString(std::ostringstream &ss, ssize_t cursor, ssize_t start, ssize_t end, bool use_comma) const {
|
||||
const bool isScalar = ndim_ == 0 && end - start == 1;
|
||||
constexpr auto isFloat =
|
||||
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
|
||||
|
@ -265,33 +265,43 @@ class TensorDataImpl : public TensorData {
|
|||
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value ? "True" : "False");
|
||||
}
|
||||
} else {
|
||||
constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
|
||||
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
|
||||
constexpr auto isSigned = std::is_same<T, int64_t>::value;
|
||||
if constexpr (isSigned) {
|
||||
if (!isScalar && static_cast<int64_t>(value) >= 0) {
|
||||
ss << ' ';
|
||||
}
|
||||
}
|
||||
|
||||
// Set width and indent for different int type.
|
||||
// Set width and indent for different int type with signed position.
|
||||
//
|
||||
// int8/uint8 width: 3
|
||||
// int16/uint16 width: 5
|
||||
// int32/uint32 width: 10
|
||||
// int64/uint64 width: NOT SET
|
||||
if constexpr (std::is_same<T, int8_t>::value) {
|
||||
ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast<int16_t>(value);
|
||||
} else if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
// uint8 width: 3, [0, 255]
|
||||
// int8 width: 4, [-128, 127]
|
||||
// uint16 width: 5, [0, 65535]
|
||||
// int16 width: 6, [-32768, 32767]
|
||||
// uint32 width: 10, [0, 4294967295]
|
||||
// int32 width: 11, [-2147483648, 2147483647]
|
||||
// uint64 width: NOT SET (20, [0, 18446744073709551615])
|
||||
// int64 width: NOT SET (20, [-9223372036854775808, 9223372036854775807])
|
||||
if constexpr (std::is_same<T, uint8_t>::value) {
|
||||
ss << std::setw(3) << std::setiosflags(std::ios::right) << static_cast<uint16_t>(value);
|
||||
} else if constexpr (std::is_same<T, int16_t>::value || std::is_same<T, uint16_t>::value) {
|
||||
} else if constexpr (std::is_same<T, int8_t>::value) {
|
||||
ss << std::setw(4) << std::setiosflags(std::ios::right) << static_cast<int16_t>(value);
|
||||
} else if constexpr (std::is_same<T, uint16_t>::value) {
|
||||
ss << std::setw(5) << std::setiosflags(std::ios::right) << value;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value || std::is_same<T, uint32_t>::value) {
|
||||
} else if constexpr (std::is_same<T, int16_t>::value) {
|
||||
ss << std::setw(6) << std::setiosflags(std::ios::right) << value;
|
||||
} else if constexpr (std::is_same<T, uint32_t>::value) {
|
||||
ss << std::setw(10) << std::setiosflags(std::ios::right) << value;
|
||||
} else if constexpr (std::is_same<T, int32_t>::value) {
|
||||
ss << std::setw(11) << std::setiosflags(std::ios::right) << value;
|
||||
} else {
|
||||
ss << value;
|
||||
}
|
||||
}
|
||||
if (!isScalar && i != end - 1) {
|
||||
if (use_comma) {
|
||||
ss << ',';
|
||||
}
|
||||
ss << ' ';
|
||||
}
|
||||
if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {
|
||||
|
@ -301,7 +311,8 @@ class TensorDataImpl : public TensorData {
|
|||
}
|
||||
}
|
||||
|
||||
void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth) const {
|
||||
void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth,
|
||||
bool use_comma) const {
|
||||
if (depth >= static_cast<ssize_t>(ndim_)) {
|
||||
return;
|
||||
}
|
||||
|
@ -309,11 +320,11 @@ class TensorDataImpl : public TensorData {
|
|||
if (depth == static_cast<ssize_t>(ndim_) - 1) { // Bottom dimension
|
||||
ssize_t num = shape[depth];
|
||||
if (num > kThreshold && ndim_ > 1) {
|
||||
OutputDataString(ss, *cursor, 0, kThreshold / 2);
|
||||
OutputDataString(ss, *cursor, 0, kThreshold / 2, use_comma);
|
||||
ss << ' ' << kEllipsis << ' ';
|
||||
OutputDataString(ss, *cursor, num - kThreshold / 2, num);
|
||||
OutputDataString(ss, *cursor, num - kThreshold / 2, num, use_comma);
|
||||
} else {
|
||||
OutputDataString(ss, *cursor, 0, num);
|
||||
OutputDataString(ss, *cursor, 0, num, use_comma);
|
||||
}
|
||||
*cursor += num;
|
||||
} else { // Middle dimension
|
||||
|
@ -321,13 +332,19 @@ class TensorDataImpl : public TensorData {
|
|||
// Handle the first half.
|
||||
for (ssize_t i = 0; i < std::min(static_cast<ssize_t>(kThreshold / 2), num); i++) {
|
||||
if (i > 0) {
|
||||
if (use_comma) {
|
||||
ss << ',';
|
||||
}
|
||||
ss << '\n';
|
||||
ss << std::setw(depth + 1) << ' '; // Add the indent.
|
||||
}
|
||||
SummaryStringRecursive(ss, shape, cursor, depth + 1);
|
||||
SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma);
|
||||
}
|
||||
// Handle the ignored part.
|
||||
if (num > kThreshold) {
|
||||
if (use_comma) {
|
||||
ss << ',';
|
||||
}
|
||||
ss << '\n';
|
||||
ss << std::setw(depth + 1) << ' '; // Add the indent.
|
||||
ss << kEllipsis;
|
||||
|
@ -343,10 +360,14 @@ class TensorDataImpl : public TensorData {
|
|||
}
|
||||
// Handle the second half.
|
||||
if (num > kThreshold / 2) {
|
||||
for (ssize_t i = num - kThreshold / 2; i < num; i++) {
|
||||
auto continue_pos = num - kThreshold / 2;
|
||||
for (ssize_t i = continue_pos; i < num; i++) {
|
||||
if (use_comma && i != continue_pos) {
|
||||
ss << ',';
|
||||
}
|
||||
ss << '\n';
|
||||
ss << std::setw(depth + 1) << ' '; // Add the indent.
|
||||
SummaryStringRecursive(ss, shape, cursor, depth + 1);
|
||||
SummaryStringRecursive(ss, shape, cursor, depth + 1, use_comma);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -487,29 +508,35 @@ std::string Tensor::GetShapeAndDataTypeInfo() const {
|
|||
return buf.str();
|
||||
}
|
||||
|
||||
std::string Tensor::ToString() const {
|
||||
constexpr int small_tensor_size = 30;
|
||||
std::string Tensor::ToStringInternal(int limit_size) const {
|
||||
std::ostringstream buf;
|
||||
auto dtype = Dtype();
|
||||
MS_EXCEPTION_IF_NULL(dtype);
|
||||
data_sync();
|
||||
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ',';
|
||||
if (DataSize() < small_tensor_size) {
|
||||
if (limit_size <= 0 || DataSize() < limit_size) {
|
||||
// Only print data for small tensor.
|
||||
buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_) << ')';
|
||||
buf << ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, false) << ')';
|
||||
} else {
|
||||
buf << " [...])";
|
||||
}
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
std::string Tensor::ToString() const {
|
||||
constexpr int small_tensor_size = 30;
|
||||
return ToStringInternal(small_tensor_size);
|
||||
}
|
||||
|
||||
std::string Tensor::ToStringNoLimit() const { return ToStringInternal(0); }
|
||||
|
||||
std::string Tensor::ToStringRepr() const {
|
||||
std::ostringstream buf;
|
||||
auto dtype = Dtype();
|
||||
MS_EXCEPTION_IF_NULL(dtype);
|
||||
data_sync();
|
||||
buf << "Tensor(shape=" << ShapeToString(shape_) << ", dtype=" << dtype->ToString() << ','
|
||||
<< ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_) << ')';
|
||||
<< ((data().ndim() > 1) ? '\n' : ' ') << data().ToString(data_type_, shape_, true) << ')';
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ class TensorData {
|
|||
/// Is data equals.
|
||||
virtual bool equals(const TensorData &other) const = 0;
|
||||
/// To string.
|
||||
virtual std::string ToString(const TypeId type, const ShapeVector &shape) const = 0;
|
||||
virtual std::string ToString(const TypeId type, const ShapeVector &shape, bool use_comma) const = 0;
|
||||
};
|
||||
|
||||
using TensorDataPtr = std::shared_ptr<TensorData>;
|
||||
|
@ -208,6 +208,10 @@ class Tensor : public MetaTensor {
|
|||
|
||||
std::string GetShapeAndDataTypeInfo() const;
|
||||
|
||||
std::string ToStringInternal(int limit_size) const;
|
||||
|
||||
std::string ToStringNoLimit() const;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
std::string ToStringRepr() const;
|
||||
|
|
Loading…
Reference in New Issue