!124 add suport for bool in print_kernels
Merge pull request !124 from yankai10/add_support_for_bool
This commit is contained in:
commit
e579be472f
|
@ -37,13 +37,13 @@ static std::map<std::string, TypeId> print_type_map = {
|
||||||
{"int32_t", TypeId::kNumberTypeInt32}, {"uint32_t", TypeId::kNumberTypeUInt32},
|
{"int32_t", TypeId::kNumberTypeInt32}, {"uint32_t", TypeId::kNumberTypeUInt32},
|
||||||
{"int64_t", TypeId::kNumberTypeInt64}, {"uint64_t", TypeId::kNumberTypeUInt64},
|
{"int64_t", TypeId::kNumberTypeInt64}, {"uint64_t", TypeId::kNumberTypeUInt64},
|
||||||
{"float16", TypeId::kNumberTypeFloat16}, {"float", TypeId::kNumberTypeFloat32},
|
{"float16", TypeId::kNumberTypeFloat16}, {"float", TypeId::kNumberTypeFloat32},
|
||||||
{"double", TypeId::kNumberTypeFloat64}};
|
{"double", TypeId::kNumberTypeFloat64}, {"bool", TypeId::kNumberTypeBool}};
|
||||||
|
|
||||||
static std::map<std::string, size_t> type_size_map = {
|
static std::map<std::string, size_t> type_size_map = {
|
||||||
{"int8_t", sizeof(int8_t)}, {"uint8_t", sizeof(uint8_t)}, {"int16_t", sizeof(int16_t)},
|
{"int8_t", sizeof(int8_t)}, {"uint8_t", sizeof(uint8_t)}, {"int16_t", sizeof(int16_t)},
|
||||||
{"uint16_t", sizeof(uint16_t)}, {"int32_t", sizeof(int32_t)}, {"uint32_t", sizeof(uint32_t)},
|
{"uint16_t", sizeof(uint16_t)}, {"int32_t", sizeof(int32_t)}, {"uint32_t", sizeof(uint32_t)},
|
||||||
{"int64_t", sizeof(int64_t)}, {"uint64_t", sizeof(uint64_t)}, {"float16", sizeof(float) / 2},
|
{"int64_t", sizeof(int64_t)}, {"uint64_t", sizeof(uint64_t)}, {"float16", sizeof(float) / 2},
|
||||||
{"float", sizeof(float)}, {"double", sizeof(double)}};
|
{"float", sizeof(float)}, {"double", sizeof(double)}, {"bool", sizeof(bool)}};
|
||||||
|
|
||||||
bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *const tensor_shape, size_t *dims) {
|
bool ParseTensorShape(const std::string &input_shape_str, std::vector<int> *const tensor_shape, size_t *dims) {
|
||||||
if (tensor_shape == nullptr) {
|
if (tensor_shape == nullptr) {
|
||||||
|
@ -107,7 +107,7 @@ bool ConvertDataItem2Tensor(const std::vector<tdt::DataItem> &items) {
|
||||||
} else {
|
} else {
|
||||||
auto type_iter = print_type_map.find(item.tensorType_);
|
auto type_iter = print_type_map.find(item.tensorType_);
|
||||||
if (type_iter == print_type_map.end()) {
|
if (type_iter == print_type_map.end()) {
|
||||||
MS_LOG(ERROR) << "type of tensor need to print is not soupport" << item.tensorType_;
|
MS_LOG(ERROR) << "type of tensor need to print is not support " << item.tensorType_;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto type_id = type_iter->second;
|
auto type_id = type_iter->second;
|
||||||
|
|
Loading…
Reference in New Issue