forked from mindspore-Ecosystem/mindspore
fix scalar print
This commit is contained in:
parent
831ceba6eb
commit
7bf8d2696c
|
@ -32,6 +32,7 @@
|
|||
namespace mindspore {
|
||||
const char kShapeSeperator[] = ",";
|
||||
const char kShapeScalar[] = "[0]";
|
||||
const char kShapeNone[] = "[]";
|
||||
static std::map<std::string, TypeId> print_type_map = {
|
||||
{"int8_t", TypeId::kNumberTypeInt8}, {"uint8_t", TypeId::kNumberTypeUInt8},
|
||||
{"int16_t", TypeId::kNumberTypeInt16}, {"uint16_t", TypeId::kNumberTypeUInt16},
|
||||
|
@ -163,9 +164,9 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
|
|||
}
|
||||
std::shared_ptr<std::string> str_data_ptr = std::static_pointer_cast<std::string>(item.dataPtr_);
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
if (item.tensorShape_ == kShapeScalar) {
|
||||
if (item.tensorShape_ == kShapeScalar || item.tensorShape_ == kShapeNone) {
|
||||
if (!judgeLengthValid(str_data_ptr->size(), item.tensorType_)) {
|
||||
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
|
||||
MS_LOG(EXCEPTION) << "Print op receive data length is invalid.";
|
||||
}
|
||||
convertDataItem2Scalar(str_data_ptr->data(), item.tensorType_, &buf);
|
||||
continue;
|
||||
|
|
Loading…
Reference in New Issue