forked from OSchip/llvm-project
[mlir] Fix DenseElementsAttr::mapValues(i1, splat).
Splat of bool is encoded as a byte with all-ones in it [1]. Without this
change, this piece of code:
auto xs = builder.getI32TensorAttr({42, 42, 42, 42});
auto xs2 = xs.mapValues(builder.getI1Type(), [](const llvm::APInt &x) {
return x.isZero() ? llvm::APInt::getZero(1) : llvm::APInt::getAllOnes(1);
});
xs2.dump();
Prints:
dense<[true, false, false, false]> : tensor<4xi1>
Because only the first bit is set. This applies to both
DenseIntElementsAttr::mapValues() and DenseFPElementsAttr::mapValues().
[1]: e877b42e2c/mlir/lib/IR/BuiltinAttributes.cpp (L984)
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D132767
This commit is contained in:
parent
7f2b016b82
commit
ae60a4a0ef
|
@ -1526,7 +1526,12 @@ static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
|
|||
|
||||
// Check for the splat case.
|
||||
if (attr.isSplat()) {
|
||||
processElt(*attr.begin(), /*index=*/0);
|
||||
if (bitWidth == 1) {
|
||||
// Handle the special encoding of splat of bool.
|
||||
data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
|
||||
} else {
|
||||
processElt(*attr.begin(), /*index=*/0);
|
||||
}
|
||||
return newArrayType;
|
||||
}
|
||||
|
||||
|
|
|
@ -209,6 +209,40 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
|
|||
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
|
||||
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
|
||||
}
|
||||
|
||||
TEST(DenseSplatMapValuesTest, I32ToTrue) {
|
||||
MLIRContext context;
|
||||
const int elementValue = 12;
|
||||
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||
IntegerType intTy = IntegerType::get(&context, 32);
|
||||
RankedTensorType shape = RankedTensorType::get({4}, intTy);
|
||||
|
||||
auto attr =
|
||||
DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
|
||||
.mapValues(boolTy, [](const APInt &x) {
|
||||
return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
|
||||
});
|
||||
EXPECT_EQ(attr.getNumElements(), 4);
|
||||
EXPECT_TRUE(attr.isSplat());
|
||||
EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
|
||||
}
|
||||
|
||||
TEST(DenseSplatMapValuesTest, I32ToFalse) {
|
||||
MLIRContext context;
|
||||
const int elementValue = 0;
|
||||
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||
IntegerType intTy = IntegerType::get(&context, 32);
|
||||
RankedTensorType shape = RankedTensorType::get({4}, intTy);
|
||||
|
||||
auto attr =
|
||||
DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
|
||||
.mapValues(boolTy, [](const APInt &x) {
|
||||
return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
|
||||
});
|
||||
EXPECT_EQ(attr.getNumElements(), 4);
|
||||
EXPECT_TRUE(attr.isSplat());
|
||||
EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue