forked from mindspore-Ecosystem/mindspore
fix scalar tensor shape=[]
This commit is contained in:
parent
74dc7d069d
commit
311e7be605
|
@ -103,7 +103,7 @@ template <typename T>
|
|||
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *const buf) {
|
||||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
*buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value=";
|
||||
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
|
||||
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
|
||||
if constexpr (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value) {
|
||||
const int int_data = static_cast<int>(*data_ptr);
|
||||
|
@ -117,7 +117,7 @@ void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type
|
|||
MS_EXCEPTION_IF_NULL(str_data_ptr);
|
||||
MS_EXCEPTION_IF_NULL(buf);
|
||||
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
|
||||
*buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value=";
|
||||
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
|
||||
if (*data_ptr) {
|
||||
*buf << "True)\n";
|
||||
} else {
|
||||
|
|
|
@ -25,7 +25,7 @@ expect_array = {'Bool': '\n[[ True False]\n [False True]]', 'UInt': '\n[[1 2 3]
|
|||
'[ *.********e*** **.********e*** *.********e***]]'}
|
||||
|
||||
def get_expect_value(res):
|
||||
if res[0] == '[1]':
|
||||
if res[0] == '[]':
|
||||
if res[1] == 'Bool':
|
||||
return expect_scalar['Bool']
|
||||
if res[1] in ['Uint8', 'Uint16', 'Uint32', 'Uint64']:
|
||||
|
|
Loading…
Reference in New Issue