forked from mindspore-Ecosystem/mindspore
support uint for mindir datatype
This commit is contained in:
parent
7187887b63
commit
96d656aa00
|
@ -31,7 +31,7 @@
|
|||
namespace mindspore {
|
||||
using FloatPtr = std::shared_ptr<Float>;
|
||||
using IntPtr = std::shared_ptr<Int>;
|
||||
|
||||
using UIntPtr = std::shared_ptr<UInt>;
|
||||
// anf type to mindir type map
|
||||
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = {
|
||||
{kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL},
|
||||
|
@ -56,6 +56,13 @@ static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_ma
|
|||
{64, mind_ir::TensorProto_DataType_INT64},
|
||||
};
|
||||
|
||||
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_uint_map = {
|
||||
{8, mind_ir::TensorProto_DataType_UINT8},
|
||||
{16, mind_ir::TensorProto_DataType_UINT16},
|
||||
{32, mind_ir::TensorProto_DataType_UINT32},
|
||||
{64, mind_ir::TensorProto_DataType_UINT64},
|
||||
};
|
||||
|
||||
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = {
|
||||
{16, mind_ir::TensorProto_DataType_FLOAT16},
|
||||
{32, mind_ir::TensorProto_DataType_FLOAT},
|
||||
|
@ -117,6 +124,7 @@ class IrExportBuilder {
|
|||
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
|
||||
mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
|
||||
mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits);
|
||||
mind_ir::TensorProto_DataType GetMindirDataBitsUIntType(int bits);
|
||||
std::string GetNodeName(const AnfNodePtr &node);
|
||||
std::string GetUniqueNodeName(const AnfNodePtr &node);
|
||||
std::string GetOpTypeName(const AnfNodePtr &node);
|
||||
|
@ -243,6 +251,14 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits
|
|||
return iter->second;
|
||||
}
|
||||
|
||||
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) {
|
||||
auto iter = g_data_bits_uint_map.find(bits);
|
||||
if (iter == g_data_bits_uint_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
|
||||
auto iter = g_data_bits_float_map.find(bits);
|
||||
if (iter == g_data_bits_float_map.end()) {
|
||||
|
@ -551,6 +567,11 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
|
|||
tensor_proto->set_name("value0");
|
||||
auto int_value = value->cast<IntPtr>();
|
||||
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
|
||||
} else if (value->isa<UInt>()) {
|
||||
attr_proto->set_ref_attr_name("type:value0");
|
||||
tensor_proto->set_name("value0");
|
||||
auto float_value = value->cast<UIntPtr>();
|
||||
tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits()));
|
||||
} else if (value->isa<Float>()) {
|
||||
attr_proto->set_ref_attr_name("type:value0");
|
||||
tensor_proto->set_name("value0");
|
||||
|
|
Loading…
Reference in New Issue