diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 14fb485a56c9..c91a72344c59 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -280,52 +280,48 @@ static unsigned cachedLookup(Type t, DenseMap &cache, unsigned mlir::DataLayout::getTypeSize(Type t) const { checkValid(); return cachedLookup(t, sizes, [&](Type ty) { - if (originalLayout) { - DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); - if (auto iface = dyn_cast(scope)) - return iface.getTypeSize(ty, *this, list); - return detail::getDefaultTypeSize(ty, *this, list); - } - return detail::getDefaultTypeSize(ty, *this, {}); + DataLayoutEntryList list; + if (originalLayout) + list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast_or_null(scope)) + return iface.getTypeSize(ty, *this, list); + return detail::getDefaultTypeSize(ty, *this, list); }); } unsigned mlir::DataLayout::getTypeSizeInBits(Type t) const { checkValid(); return cachedLookup(t, bitsizes, [&](Type ty) { - if (originalLayout) { - DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); - if (auto iface = dyn_cast(scope)) - return iface.getTypeSizeInBits(ty, *this, list); - return detail::getDefaultTypeSizeInBits(ty, *this, list); - } - return detail::getDefaultTypeSizeInBits(ty, *this, {}); + DataLayoutEntryList list; + if (originalLayout) + list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast_or_null(scope)) + return iface.getTypeSizeInBits(ty, *this, list); + return detail::getDefaultTypeSizeInBits(ty, *this, list); }); } unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const { checkValid(); return cachedLookup(t, abiAlignments, [&](Type ty) { - if (originalLayout) { - DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); - if (auto iface = dyn_cast(scope)) - return iface.getTypeABIAlignment(ty, *this, list); - return detail::getDefaultABIAlignment(ty, *this, list); - } - return detail::getDefaultABIAlignment(ty, *this, {}); + DataLayoutEntryList list; + if (originalLayout) + list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast_or_null(scope)) + return iface.getTypeABIAlignment(ty, *this, list); + return detail::getDefaultABIAlignment(ty, *this, list); }); } unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const { checkValid(); return cachedLookup(t, preferredAlignments, [&](Type ty) { - if (originalLayout) { - DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID()); - if (auto iface = dyn_cast(scope)) - return iface.getTypePreferredAlignment(ty, *this, list); - return detail::getDefaultPreferredAlignment(ty, *this, list); - } - return detail::getDefaultPreferredAlignment(ty, *this, {}); + DataLayoutEntryList list; + if (originalLayout) + list = originalLayout.getSpecForType(ty.getTypeID()); + if (auto iface = dyn_cast_or_null(scope)) + return iface.getTypePreferredAlignment(ty, *this, list); + return detail::getDefaultPreferredAlignment(ty, *this, list); }); } diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp index e9d69f02442d..287839120ab2 100644 --- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp @@ -227,6 +227,27 @@ struct DLTestDialect : Dialect { TEST(DataLayout, FallbackDefault) { const char *ir = R"MLIR( +module {} + )MLIR"; + + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(registry); + + OwningModuleRef module = parseSourceString(ir, &ctx); + DataLayout layout(module.get()); + EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); + EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u); + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); + EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u); + EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); + EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u); + EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u); + EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u); +} + +TEST(DataLayout, NullSpec) { + const char *ir = R"MLIR( "dltest.op_with_layout"() : () -> () )MLIR"; @@ -238,14 +259,14 @@ TEST(DataLayout, FallbackDefault) { auto op = cast(module->getBody()->getOperations().front()); DataLayout layout(op); - EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u); - EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u); - EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u); - EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 16u); - EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 8u); - EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 2u); - EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 8u); - EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u); + EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u); + EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u); + EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u); + EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u); + EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u); + EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u); + EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u); + EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u); } TEST(DataLayout, EmptySpec) {