Fix the issue of Tensor and SubModule.

This commit is contained in:
Zhang Qinghua 2020-07-15 15:49:39 +08:00
parent 3e691e54f5
commit de95cc8c1c
2 changed files with 21 additions and 7 deletions

View File

@ -176,7 +176,7 @@ static const char *GetSubModuleName(SubModuleId module_id) {
"PYNATIVE", // SM_PYNATIVE "PYNATIVE", // SM_PYNATIVE
"SESSION", // SM_SESSION "SESSION", // SM_SESSION
"UTILS", // SM_UTILS "UTILS", // SM_UTILS
"VM" // SM_VM "VM", // SM_VM
"ABSTRACT" // SM_ABSTRACT "ABSTRACT" // SM_ABSTRACT
}; };

View File

@ -185,6 +185,10 @@ class TensorDataImpl : public TensorData {
} }
std::ostringstream ss; std::ostringstream ss;
if (data_size_ == 1 && ndim_ == 0) { // Scalar
OutputDataString(ss, type, 0, 0, 1);
return ss.str();
}
ssize_t cursor = 0; ssize_t cursor = 0;
SummaryStringRecursive(ss, type, shape, &cursor, 0); SummaryStringRecursive(ss, type, shape, &cursor, 0);
return ss.str(); return ss.str();
@ -192,23 +196,32 @@ class TensorDataImpl : public TensorData {
private: private:
void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const { void OutputDataString(std::ostringstream &ss, const TypeId type, ssize_t cursor, ssize_t start, ssize_t end) const {
bool isScalar = ndim_ == 0 && end - start == 1;
int linefeedThreshold; int linefeedThreshold;
constexpr auto isFloat = constexpr auto isFloat =
std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value; std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value;
for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) { for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) {
const auto value = data_[cursor + i]; const auto value = data_[cursor + i];
if constexpr (isFloat) { if constexpr (isFloat) {
ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right) if (isScalar) {
<< value; ss << value;
} else {
ss << std::setw(15) << std::setprecision(8) << std::setiosflags(std::ios::scientific | std::ios::right)
<< value;
}
linefeedThreshold = kThreshold1DFloat; linefeedThreshold = kThreshold1DFloat;
} else if (type == kNumberTypeBool) { } else if (type == kNumberTypeBool) {
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True"); if (isScalar) {
ss << (value == 0 ? "False" : "True");
} else {
ss << std::setw(5) << std::setiosflags(std::ios::right) << (value == 0 ? "False" : "True");
}
linefeedThreshold = kThreshold1DBool; linefeedThreshold = kThreshold1DBool;
} else { } else {
constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value || constexpr auto isSigned = std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value; std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value;
if constexpr (isSigned) { if constexpr (isSigned) {
if (static_cast<int64_t>(value) >= 0) { if (!isScalar && static_cast<int64_t>(value) >= 0) {
ss << ' '; ss << ' ';
} }
} }
@ -221,10 +234,11 @@ class TensorDataImpl : public TensorData {
} }
linefeedThreshold = kThreshold1DInt; linefeedThreshold = kThreshold1DInt;
} }
if (i != end - 1) { if (!isScalar && i != end - 1) {
ss << ' '; ss << ' ';
} }
if (ndim_ == 1 && (i + 1) % linefeedThreshold == 0) { // Add a line feed every {threshold of type} for 1D tensor. if (!isScalar && ndim_ == 1 && (i + 1) % linefeedThreshold == 0) {
// Add a line feed every {threshold of type} for 1D tensor.
ss << '\n' << ' '; ss << '\n' << ' ';
} }
} }