forked from mindspore-Ecosystem/mindspore
!1202 Fix tensor print order
Merge pull request !1202 from zjun/fix_tensor_print
This commit is contained in:
commit
f23bfe0d71
|
@ -50,6 +50,7 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
|
|||
if (tensor_shape == nullptr) {
|
||||
return false;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(dims);
|
||||
std::string shape_str = input_shape_str;
|
||||
if (shape_str.size() <= 2) {
|
||||
return false;
|
||||
|
@ -71,6 +72,8 @@ bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *cons
|
|||
|
||||
bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *const print_tensor,
|
||||
const size_t &memory_size) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(print_tensor);
|
||||
auto *tensor_data_ptr = static_cast<uint8_t *>(print_tensor->data_c(true));
|
||||
MS_EXCEPTION_IF_NULL(tensor_data_ptr);
|
||||
auto cp_ret =
|
||||
|
@ -83,55 +86,57 @@ bool PrintTensorToString(const char *str_data_ptr, mindspore::tensor::Tensor *co
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type) {
|
||||
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
|
||||
std::ostringstream buf_scalar;
|
||||
buf_scalar << "Tensor shape :1 " << tensor_type;
|
||||
buf_scalar << "\nval:";
|
||||
buf_scalar << *data_ptr;
|
||||
std::cout << buf_scalar.str() << std::endl;
|
||||
*buf << "Tensor shape:[1] " << tensor_type;
|
||||
*buf << "\nval:";
|
||||
*buf << *data_ptr << "\n";
|
||||
}
|
||||
|
||||
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type) {
|
||||
void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
|
||||
std::ostringstream buf_scalar;
|
||||
buf_scalar << "Tensor shape :1 " << tensor_type;
|
||||
buf_scalar << "\nval:";
|
||||
if (*data_ptr == true) {
|
||||
buf_scalar << "True";
|
||||
*buf << "Tensor shape:[1] " << tensor_type;
|
||||
*buf << "\nval:";
|
||||
if (*data_ptr) {
|
||||
*buf << "True\n";
|
||||
} else {
|
||||
buf_scalar << "False";
|
||||
*buf << "False\n";
|
||||
}
|
||||
std::cout << buf_scalar.str() << std::endl;
|
||||
}
|
||||
|
||||
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type) {
|
||||
void convertDataItem2Scalar(const char *str_data_ptr, const string &tensor_type, std::ostringstream *buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
auto type_iter = print_type_map.find(tensor_type);
|
||||
auto type_id = type_iter->second;
|
||||
if (type_id == TypeId::kNumberTypeBool) {
|
||||
PrintScalarToBoolString(str_data_ptr, tensor_type);
|
||||
PrintScalarToBoolString(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt8) {
|
||||
PrintScalarToString<int8_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<int8_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt8) {
|
||||
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<uint8_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt16) {
|
||||
PrintScalarToString<int16_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<int16_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt16) {
|
||||
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<uint16_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt32) {
|
||||
PrintScalarToString<int32_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<int32_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt32) {
|
||||
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<uint32_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeInt64) {
|
||||
PrintScalarToString<int64_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<int64_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeUInt64) {
|
||||
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<uint64_t>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeFloat16) {
|
||||
PrintScalarToString<float16>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<float16>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeFloat32) {
|
||||
PrintScalarToString<float>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<float>(str_data_ptr, tensor_type, buf);
|
||||
} else if (type_id == TypeId::kNumberTypeFloat64) {
|
||||
PrintScalarToString<double>(str_data_ptr, tensor_type);
|
||||
PrintScalarToString<double>(str_data_ptr, tensor_type, buf);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Cannot print scalar because of unsupport data type: " << tensor_type << ".";
|
||||
}
|
||||
|
@ -142,11 +147,7 @@ bool judgeLengthValid(const size_t str_len, const string &tensor_type) {
|
|||
if (type_iter == type_size_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "type of scalar to print is not support.";
|
||||
}
|
||||
|
||||
if (str_len != type_iter->second) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
return str_len == type_iter->second;
|
||||
}
|
||||
|
||||
#ifndef NO_DLIB
|
||||
|
@ -166,7 +167,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
|
|||
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
|
||||
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
|
||||
}
|
||||
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_);
|
||||
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_, &buf);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue