!17405 GetInputTensorValue supports more data type

From: @looop5
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-06-01 09:48:38 +08:00 committed by Gitee
commit dad20255d1
1 changed files with 24 additions and 16 deletions

View File

@ -708,27 +708,35 @@ bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann:
return false;
}
if (type_id == kFloat32->type_id()) {
float *val = static_cast<float *>(data);
MS_EXCEPTION_IF_NULL(val);
(*node_json)["value"] = val[0];
MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
return true;
if (type_id == kFloat64->type_id()) {
(*node_json)["value"] = static_cast<double *>(data)[0];
} else if (type_id == kFloat32->type_id()) {
(*node_json)["value"] = static_cast<float *>(data)[0];
} else if (type_id == kFloat16->type_id()) {
float16 *val = static_cast<float16 *>(data);
MS_EXCEPTION_IF_NULL(val);
(*node_json)["value"] = static_cast<float>(val[0]);
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
return true;
} else if (type_id == kUInt64->type_id()) {
(*node_json)["value"] = static_cast<uint64_t *>(data)[0];
} else if (type_id == kUInt32->type_id()) {
(*node_json)["value"] = static_cast<uint32_t *>(data)[0];
} else if (type_id == kUInt16->type_id()) {
(*node_json)["value"] = static_cast<uint16_t *>(data)[0];
} else if (type_id == kUInt8->type_id()) {
(*node_json)["value"] = static_cast<uint8_t *>(data)[0];
} else if (type_id == kInt64->type_id()) {
(*node_json)["value"] = static_cast<int64_t *>(data)[0];
} else if (type_id == kInt32->type_id()) {
int *val = static_cast<int *>(data);
MS_EXCEPTION_IF_NULL(val);
(*node_json)["value"] = val[0];
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
return true;
(*node_json)["value"] = static_cast<int32_t *>(data)[0];
} else if (type_id == kInt16->type_id()) {
(*node_json)["value"] = static_cast<int16_t *>(data)[0];
} else if (type_id == kInt8->type_id()) {
(*node_json)["value"] = static_cast<int8_t *>(data)[0];
} else if (type_id == kBool->type_id()) {
(*node_json)["value"] = static_cast<bool *>(data)[0];
} else {
MS_LOG(EXCEPTION) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
}
MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
return false;
return true;
}
bool IsWeightBoundary(const AnfNodePtr &node) {