From ae60a4a0efff337425638d04005b33a73dc5792f Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Tue, 6 Sep 2022 21:28:23 +0200 Subject: [PATCH] [mlir] Fix DenseElementsAttr::mapValues(i1, splat). Splat of bool is encoded as a byte with all-ones in it [1]. Without this change, this piece of code: auto xs = builder.getI32TensorAttr({42, 42, 42, 42}); auto xs2 = xs.mapValues(builder.getI1Type(), [](const llvm::APInt &x) { return x.isZero() ? llvm::APInt::getZero(1) : llvm::APInt::getAllOnes(1); }); xs2.dump(); Prints: dense<[true, false, false, false]> : tensor<4xi1> Because only the first bit is set. This applies to both DenseIntElementsAttr::mapValues() and DenseFPElementsAttr::mapValues(). [1]: https://github.com/llvm/llvm-project/blob/e877b42e2c70813352c1963ea33e992f481d5cba/mlir/lib/IR/BuiltinAttributes.cpp#L984 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D132767 --- mlir/lib/IR/BuiltinAttributes.cpp | 7 +++++- mlir/unittests/IR/AttributeTest.cpp | 34 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 810672e3d7ed..22eff2dc34b9 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1526,7 +1526,12 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, // Check for the splat case. if (attr.isSplat()) { - processElt(*attr.begin(), /*index=*/0); + if (bitWidth == 1) { + // Handle the special encoding of splat of bool. + data[0] = mapping(*attr.begin()).isZero() ? 0 : -1; + } else { + processElt(*attr.begin(), /*index=*/0); + } return newArrayType; } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index e393b83df78d..cffac41bd953 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -209,6 +209,40 @@ TEST(DenseScalarTest, ExtractZeroRankElement) { auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); EXPECT_TRUE(attr.getValues()[0] == value); } + +TEST(DenseSplatMapValuesTest, I32ToTrue) { + MLIRContext context; + const int elementValue = 12; + IntegerType boolTy = IntegerType::get(&context, 1); + IntegerType intTy = IntegerType::get(&context, 32); + RankedTensorType shape = RankedTensorType::get({4}, intTy); + + auto attr = + DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})) + .mapValues(boolTy, [](const APInt &x) { + return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); + }); + EXPECT_EQ(attr.getNumElements(), 4); + EXPECT_TRUE(attr.isSplat()); + EXPECT_TRUE(attr.getSplatValue().getValue()); +} + +TEST(DenseSplatMapValuesTest, I32ToFalse) { + MLIRContext context; + const int elementValue = 0; + IntegerType boolTy = IntegerType::get(&context, 1); + IntegerType intTy = IntegerType::get(&context, 32); + RankedTensorType shape = RankedTensorType::get({4}, intTy); + + auto attr = + DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})) + .mapValues(boolTy, [](const APInt &x) { + return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1); + }); + EXPECT_EQ(attr.getNumElements(), 4); + EXPECT_TRUE(attr.isSplat()); + EXPECT_FALSE(attr.getSplatValue().getValue()); +} } // namespace //===----------------------------------------------------------------------===//