forked from mindspore-Ecosystem/mindspore
!17405 GetInputTensorValue supports more data type
From: @looop5 Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
dad20255d1
|
@ -708,27 +708,35 @@ bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann:
|
|||
return false;
|
||||
}
|
||||
|
||||
if (type_id == kFloat32->type_id()) {
|
||||
float *val = static_cast<float *>(data);
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
(*node_json)["value"] = val[0];
|
||||
MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
|
||||
return true;
|
||||
if (type_id == kFloat64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<double *>(data)[0];
|
||||
} else if (type_id == kFloat32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<float *>(data)[0];
|
||||
} else if (type_id == kFloat16->type_id()) {
|
||||
float16 *val = static_cast<float16 *>(data);
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
(*node_json)["value"] = static_cast<float>(val[0]);
|
||||
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
|
||||
return true;
|
||||
} else if (type_id == kUInt64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint64_t *>(data)[0];
|
||||
} else if (type_id == kUInt32->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint32_t *>(data)[0];
|
||||
} else if (type_id == kUInt16->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint16_t *>(data)[0];
|
||||
} else if (type_id == kUInt8->type_id()) {
|
||||
(*node_json)["value"] = static_cast<uint8_t *>(data)[0];
|
||||
} else if (type_id == kInt64->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int64_t *>(data)[0];
|
||||
} else if (type_id == kInt32->type_id()) {
|
||||
int *val = static_cast<int *>(data);
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
(*node_json)["value"] = val[0];
|
||||
MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
|
||||
return true;
|
||||
(*node_json)["value"] = static_cast<int32_t *>(data)[0];
|
||||
} else if (type_id == kInt16->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int16_t *>(data)[0];
|
||||
} else if (type_id == kInt8->type_id()) {
|
||||
(*node_json)["value"] = static_cast<int8_t *>(data)[0];
|
||||
} else if (type_id == kBool->type_id()) {
|
||||
(*node_json)["value"] = static_cast<bool *>(data)[0];
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
|
||||
}
|
||||
MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsWeightBoundary(const AnfNodePtr &node) {
|
||||
|
|
Loading…
Reference in New Issue