diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 810672e3d7ed..22eff2dc34b9 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1526,7 +1526,12 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, // Check for the splat case. if (attr.isSplat()) { - processElt(*attr.begin(), /*index=*/0); + if (bitWidth == 1) { + // Handle the special encoding of splat of bool. + data[0] = mapping(*attr.begin()).isZero() ? 0 : -1; + } else { + processElt(*attr.begin(), /*index=*/0); + } return newArrayType; } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index e393b83df78d..cffac41bd953 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -209,6 +209,40 @@ TEST(DenseScalarTest, ExtractZeroRankElement) { auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); EXPECT_TRUE(attr.getValues()[0] == value); } + +TEST(DenseSplatMapValuesTest, I32ToTrue) { + MLIRContext context; + const int elementValue = 12; + IntegerType boolTy = IntegerType::get(&context, 1); + IntegerType intTy = IntegerType::get(&context, 32); + RankedTensorType shape = RankedTensorType::get({4}, intTy); + + auto attr = + DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})) + .mapValues(boolTy, [](const APInt &x) { + return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); + }); + EXPECT_EQ(attr.getNumElements(), 4); + EXPECT_TRUE(attr.isSplat()); + EXPECT_TRUE(attr.getSplatValue().getValue()); +} + +TEST(DenseSplatMapValuesTest, I32ToFalse) { + MLIRContext context; + const int elementValue = 0; + IntegerType boolTy = IntegerType::get(&context, 1); + IntegerType intTy = IntegerType::get(&context, 32); + RankedTensorType shape = RankedTensorType::get({4}, intTy); + + auto attr = + DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})) + .mapValues(boolTy, [](const APInt &x) { + return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); + }); + EXPECT_EQ(attr.getNumElements(), 4); + EXPECT_TRUE(attr.isSplat()); + EXPECT_FALSE(attr.getSplatValue().getValue()); +} } // namespace //===----------------------------------------------------------------------===//