forked from mindspore-Ecosystem/mindspore
Fix the issue of Tensor and SubModule.
This commit is contained in:
parent
3e691e54f5
commit
de95cc8c1c
|
@ -176,7 +176,7 @@ static const char *GetSubModuleName(SubModuleId module_id) {
|
|||
"PYNATIVE", // SM_PYNATIVE
|
||||
"SESSION", // SM_SESSION
|
||||
"UTILS", // SM_UTILS
|
||||
"VM" // SM_VM
|
||||
"VM", // SM_VM
|
||||
"ABSTRACT" // SM_ABSTRACT
|
||||
};
|
||||
|
||||
|
|
|
@ -185,6 +185,10 @@ class TensorDataImpl : public TensorData {
|
|||
}
|
||||
|
||||
std::ostringstream ss;
|
||||
if (data_size_ == 1 && ndim_ == 0) { // Scalar
|
||||
OutputDataString(ss, type, 0, 0, 1);
|
||||
return ss.str();
|
||||
}
|
||||
ssize_t cursor = 0;
|
||||
SummaryStringRecursive(ss, type, shape, &cursor, 0);
|
||||
return ss.str();
|
||||
|
@ -192,23 +196,32 @@ class TensorDataImpl : public TensorData {
|
|||
|
||||
private:
|
||||
void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const {
|
||||
bool isScalar = ndim_ == 0 && end - start == 1;
|
||||
int linefeedThreshold;
|
||||
constexpr auto isFloat =
|
||||
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
|
||||
for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
|
||||
const auto value = data_[cursor + i];
|
||||
if constexpr (isFloat) {
|
||||
if (isScalar) {
|
||||
ss << value;
|
||||
} else {
|
||||
ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right)
|
||||
<< value;
|
||||
}
|
||||
linefeedThreshold = kThreshold1DFloat;
|
||||
} else if (type == kNumberTypeBool) {
|
||||
if (isScalar) {
|
||||
ss << (value == 0 ? "False" : "True");
|
||||
} else {
|
||||
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True");
|
||||
}
|
||||
linefeedThreshold = kThreshold1DBool;
|
||||
} else {
|
||||
constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
|
||||
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
|
||||
if constexpr (isSigned) {
|
||||
if (static_cast<int64_t>(value) >= 0) {
|
||||
if (!isScalar && static_cast<int64_t>(value) >= 0) {
|
||||
ss << ' ';
|
||||
}
|
||||
}
|
||||
|
@ -221,10 +234,11 @@ class TensorDataImpl : public TensorData {
|
|||
}
|
||||
linefeedThreshold = kThreshold1DInt;
|
||||
}
|
||||
if (i != end - 1) {
|
||||
if (!isScalar && i != end - 1) {
|
||||
ss << ' ';
|
||||
}
|
||||
if (ndim_ == 1 && (i + 1) % linefeedThreshold == 0) { // Add a line feed every {threshold of type} for 1D tensor.
|
||||
if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {
|
||||
// Add a line feed every {threshold of type} for 1D tensor.
|
||||
ss << '\n' << ' ';
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue