support uint for mindir datatype

This commit is contained in:
lianliguang 2021-07-29 16:06:24 +08:00
parent 7187887b63
commit 96d656aa00
1 changed files with 22 additions and 1 deletions

View File

@ -31,7 +31,7 @@
namespace mindspore {
using FloatPtr = std::shared_ptr<Float>;
using IntPtr = std::shared_ptr<Int>;
using UIntPtr = std::shared_ptr<UInt>;
// anf type to mindir type map
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_type_map = {
{kNumberTypeBool, mind_ir::TensorProto_DataType_BOOL},
@ -56,6 +56,13 @@ static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_int_ma
{64, mind_ir::TensorProto_DataType_INT64},
};
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_uint_map = {
{8, mind_ir::TensorProto_DataType_UINT8},
{16, mind_ir::TensorProto_DataType_UINT16},
{32, mind_ir::TensorProto_DataType_UINT32},
{64, mind_ir::TensorProto_DataType_UINT64},
};
static std::unordered_map<int, mind_ir::TensorProto_DataType> g_data_bits_float_map = {
{16, mind_ir::TensorProto_DataType_FLOAT16},
{32, mind_ir::TensorProto_DataType_FLOAT},
@ -117,6 +124,7 @@ class IrExportBuilder {
mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
mind_ir::TensorProto_DataType GetMindirDataBitsFloatType(int bits);
mind_ir::TensorProto_DataType GetMindirDataBitsUIntType(int bits);
std::string GetNodeName(const AnfNodePtr &node);
std::string GetUniqueNodeName(const AnfNodePtr &node);
std::string GetOpTypeName(const AnfNodePtr &node);
@ -243,6 +251,14 @@ mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsIntType(int bits
return iter->second;
}
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsUIntType(int bits) {
auto iter = g_data_bits_uint_map.find(bits);
if (iter == g_data_bits_uint_map.end()) {
MS_LOG(EXCEPTION) << "Convert bits uint error, unsupported bits! " << bits;
}
return iter->second;
}
mind_ir::TensorProto_DataType IrExportBuilder::GetMindirDataBitsFloatType(int bits) {
auto iter = g_data_bits_float_map.find(bits);
if (iter == g_data_bits_float_map.end()) {
@ -551,6 +567,11 @@ void IrExportBuilder::SetTypeToAttributeProto(const ValuePtr &value, mind_ir::At
tensor_proto->set_name("value0");
auto int_value = value->cast<IntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsIntType(int_value->nbits()));
} else if (value->isa<UInt>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");
auto float_value = value->cast<UIntPtr>();
tensor_proto->set_data_type(GetMindirDataBitsUIntType(float_value->nbits()));
} else if (value->isa<Float>()) {
attr_proto->set_ref_attr_name("type:value0");
tensor_proto->set_name("value0");