forked from OSchip/llvm-project
411 lines
14 KiB
C++
411 lines
14 KiB
C++
//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::detail;
|
|
|
|
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename EltTy>
|
|
static void testSplat(Type eltType, const EltTy &splatElt) {
|
|
RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
|
|
|
|
// Check that the generated splat is the same for 1 element and N elements.
|
|
DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
|
|
EXPECT_TRUE(splat.isSplat());
|
|
|
|
auto detectedSplat =
|
|
DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
|
|
EXPECT_EQ(detectedSplat, splat);
|
|
|
|
for (auto newValue : detectedSplat.template getValues<EltTy>())
|
|
EXPECT_TRUE(newValue == splatElt);
|
|
}
|
|
|
|
namespace {
|
|
TEST(DenseSplatTest, BoolSplat) {
|
|
MLIRContext context;
|
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
|
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
|
|
|
|
// Check that splat is automatically detected for boolean values.
|
|
/// True.
|
|
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
|
|
EXPECT_TRUE(trueSplat.isSplat());
|
|
/// False.
|
|
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
|
|
EXPECT_TRUE(falseSplat.isSplat());
|
|
EXPECT_NE(falseSplat, trueSplat);
|
|
|
|
/// Detect and handle splat within 8 elements (bool values are bit-packed).
|
|
/// True.
|
|
auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
|
|
EXPECT_EQ(detectedSplat, trueSplat);
|
|
/// False.
|
|
detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
|
|
EXPECT_EQ(detectedSplat, falseSplat);
|
|
}
|
|
|
|
TEST(DenseSplatTest, LargeBoolSplat) {
|
|
constexpr int64_t boolCount = 56;
|
|
|
|
MLIRContext context;
|
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
|
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
|
|
|
|
// Check that splat is automatically detected for boolean values.
|
|
/// True.
|
|
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
|
|
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
|
|
EXPECT_TRUE(trueSplat.isSplat());
|
|
EXPECT_TRUE(falseSplat.isSplat());
|
|
|
|
/// Detect that the large boolean arrays are properly splatted.
|
|
/// True.
|
|
SmallVector<bool, 64> trueValues(boolCount, true);
|
|
auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
|
|
EXPECT_EQ(detectedSplat, trueSplat);
|
|
/// False.
|
|
SmallVector<bool, 64> falseValues(boolCount, false);
|
|
detectedSplat = DenseElementsAttr::get(shape, falseValues);
|
|
EXPECT_EQ(detectedSplat, falseSplat);
|
|
}
|
|
|
|
TEST(DenseSplatTest, BoolNonSplat) {
|
|
MLIRContext context;
|
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
|
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
|
|
|
|
// Check that we properly handle non-splat values.
|
|
DenseElementsAttr nonSplat =
|
|
DenseElementsAttr::get(shape, {false, false, true, false, false, true});
|
|
EXPECT_FALSE(nonSplat.isSplat());
|
|
}
|
|
|
|
TEST(DenseSplatTest, OddIntSplat) {
|
|
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
|
|
MLIRContext context;
|
|
constexpr size_t intWidth = 19;
|
|
IntegerType intTy = IntegerType::get(&context, intWidth);
|
|
APInt value(intWidth, 10);
|
|
|
|
testSplat(intTy, value);
|
|
}
|
|
|
|
TEST(DenseSplatTest, Int32Splat) {
|
|
MLIRContext context;
|
|
IntegerType intTy = IntegerType::get(&context, 32);
|
|
int value = 64;
|
|
|
|
testSplat(intTy, value);
|
|
}
|
|
|
|
TEST(DenseSplatTest, IntAttrSplat) {
|
|
MLIRContext context;
|
|
IntegerType intTy = IntegerType::get(&context, 85);
|
|
Attribute value = IntegerAttr::get(intTy, 109);
|
|
|
|
testSplat(intTy, value);
|
|
}
|
|
|
|
TEST(DenseSplatTest, F32Splat) {
|
|
MLIRContext context;
|
|
FloatType floatTy = FloatType::getF32(&context);
|
|
float value = 10.0;
|
|
|
|
testSplat(floatTy, value);
|
|
}
|
|
|
|
TEST(DenseSplatTest, F64Splat) {
|
|
MLIRContext context;
|
|
FloatType floatTy = FloatType::getF64(&context);
|
|
double value = 10.0;
|
|
|
|
testSplat(floatTy, APFloat(value));
|
|
}
|
|
|
|
TEST(DenseSplatTest, FloatAttrSplat) {
|
|
MLIRContext context;
|
|
FloatType floatTy = FloatType::getF32(&context);
|
|
Attribute value = FloatAttr::get(floatTy, 10.0);
|
|
|
|
testSplat(floatTy, value);
|
|
}
|
|
|
|
TEST(DenseSplatTest, BF16Splat) {
|
|
MLIRContext context;
|
|
FloatType floatTy = FloatType::getBF16(&context);
|
|
Attribute value = FloatAttr::get(floatTy, 10.0);
|
|
|
|
testSplat(floatTy, value);
|
|
}
|
|
|
|
TEST(DenseSplatTest, StringSplat) {
|
|
MLIRContext context;
|
|
context.allowUnregisteredDialects();
|
|
Type stringType =
|
|
OpaqueType::get(StringAttr::get(&context, "test"), "string");
|
|
StringRef value = "test-string";
|
|
testSplat(stringType, value);
|
|
}
|
|
|
|
TEST(DenseSplatTest, StringAttrSplat) {
|
|
MLIRContext context;
|
|
context.allowUnregisteredDialects();
|
|
Type stringType =
|
|
OpaqueType::get(StringAttr::get(&context, "test"), "string");
|
|
Attribute stringAttr = StringAttr::get("test-string", stringType);
|
|
testSplat(stringType, stringAttr);
|
|
}
|
|
|
|
TEST(DenseComplexTest, ComplexFloatSplat) {
|
|
MLIRContext context;
|
|
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
|
|
std::complex<float> value(10.0, 15.0);
|
|
testSplat(complexType, value);
|
|
}
|
|
|
|
TEST(DenseComplexTest, ComplexIntSplat) {
|
|
MLIRContext context;
|
|
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
|
|
std::complex<int64_t> value(10, 15);
|
|
testSplat(complexType, value);
|
|
}
|
|
|
|
TEST(DenseComplexTest, ComplexAPFloatSplat) {
|
|
MLIRContext context;
|
|
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
|
|
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
|
|
testSplat(complexType, value);
|
|
}
|
|
|
|
TEST(DenseComplexTest, ComplexAPIntSplat) {
|
|
MLIRContext context;
|
|
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
|
|
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
|
|
testSplat(complexType, value);
|
|
}
|
|
|
|
TEST(DenseScalarTest, ExtractZeroRankElement) {
|
|
MLIRContext context;
|
|
const int elementValue = 12;
|
|
IntegerType intTy = IntegerType::get(&context, 32);
|
|
Attribute value = IntegerAttr::get(intTy, elementValue);
|
|
RankedTensorType shape = RankedTensorType::get({}, intTy);
|
|
|
|
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
|
|
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
|
|
}
|
|
|
|
TEST(DenseSplatMapValuesTest, I32ToTrue) {
|
|
MLIRContext context;
|
|
const int elementValue = 12;
|
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
|
IntegerType intTy = IntegerType::get(&context, 32);
|
|
RankedTensorType shape = RankedTensorType::get({4}, intTy);
|
|
|
|
auto attr =
|
|
DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
|
|
.mapValues(boolTy, [](const APInt &x) {
|
|
return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
|
|
});
|
|
EXPECT_EQ(attr.getNumElements(), 4);
|
|
EXPECT_TRUE(attr.isSplat());
|
|
EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
|
|
}
|
|
|
|
TEST(DenseSplatMapValuesTest, I32ToFalse) {
|
|
MLIRContext context;
|
|
const int elementValue = 0;
|
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
|
IntegerType intTy = IntegerType::get(&context, 32);
|
|
RankedTensorType shape = RankedTensorType::get({4}, intTy);
|
|
|
|
auto attr =
|
|
DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}))
|
|
.mapValues(boolTy, [](const APInt &x) {
|
|
return x.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
|
|
});
|
|
EXPECT_EQ(attr.getNumElements(), 4);
|
|
EXPECT_TRUE(attr.isSplat());
|
|
EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
|
|
}
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename AttrT, typename T>
|
|
static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
|
|
Type elementType) {
|
|
auto type = RankedTensorType::get(data.size(), elementType);
|
|
auto attr = AttrT::get(type, "resource",
|
|
UnmanagedAsmResourceBlob::allocateInferAlign(data));
|
|
|
|
// Check that we can access and iterate the data properly.
|
|
Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
|
|
EXPECT_TRUE(attrData.has_value());
|
|
EXPECT_EQ(*attrData, data);
|
|
|
|
// Check that we cast to this attribute when possible.
|
|
Attribute genericAttr = attr;
|
|
EXPECT_TRUE(genericAttr.template isa<AttrT>());
|
|
}
|
|
template <typename AttrT, typename T>
|
|
static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
|
|
T data[] = {0, 1, 2};
|
|
checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data),
|
|
builder.getIntegerType(intWidth));
|
|
}
|
|
|
|
namespace {
|
|
TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
|
|
MLIRContext context;
|
|
Builder builder(&context);
|
|
|
|
// Bool
|
|
bool boolData[] = {true, false, true};
|
|
checkNativeAccess<DenseBoolResourceElementsAttr>(
|
|
&context, llvm::makeArrayRef(boolData), builder.getI1Type());
|
|
|
|
// Unsigned integers
|
|
checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
|
|
checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
|
|
checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
|
|
checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
|
|
|
|
// Signed integers
|
|
checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
|
|
checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
|
|
checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
|
|
checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
|
|
|
|
// Float
|
|
float floatData[] = {0, 1, 2};
|
|
checkNativeAccess<DenseF32ResourceElementsAttr>(
|
|
&context, llvm::makeArrayRef(floatData), builder.getF32Type());
|
|
|
|
// Double
|
|
double doubleData[] = {0, 1, 2};
|
|
checkNativeAccess<DenseF64ResourceElementsAttr>(
|
|
&context, llvm::makeArrayRef(doubleData), builder.getF64Type());
|
|
}
|
|
|
|
TEST(DenseResourceElementsAttrTest, CheckNoCast) {
|
|
MLIRContext context;
|
|
Builder builder(&context);
|
|
|
|
// Create a i32 attribute.
|
|
ArrayRef<uint32_t> data;
|
|
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
|
|
Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
|
|
type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
|
|
|
|
EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
|
|
EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
|
|
EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
|
|
}
|
|
|
|
TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
|
|
MLIRContext context;
|
|
Builder builder(&context);
|
|
|
|
// Create a bool attribute with data of the incorrect type.
|
|
ArrayRef<uint32_t> data;
|
|
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
|
|
EXPECT_DEBUG_DEATH(
|
|
{
|
|
DenseBoolResourceElementsAttr::get(
|
|
type, "resource",
|
|
UnmanagedAsmResourceBlob::allocateInferAlign(data));
|
|
},
|
|
"alignment mismatch between expected alignment and blob alignment");
|
|
}
|
|
|
|
TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
|
|
MLIRContext context;
|
|
Builder builder(&context);
|
|
|
|
// Create a bool attribute with incorrect type.
|
|
ArrayRef<bool> data;
|
|
auto type = RankedTensorType::get(data.size(), builder.getI32Type());
|
|
EXPECT_DEBUG_DEATH(
|
|
{
|
|
DenseBoolResourceElementsAttr::get(
|
|
type, "resource",
|
|
UnmanagedAsmResourceBlob::allocateInferAlign(data));
|
|
},
|
|
"invalid shape element type for provided type `T`");
|
|
}
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
// SparseElementsAttr
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
TEST(SparseElementsAttrTest, GetZero) {
|
|
MLIRContext context;
|
|
context.allowUnregisteredDialects();
|
|
|
|
IntegerType intTy = IntegerType::get(&context, 32);
|
|
FloatType floatTy = FloatType::getF32(&context);
|
|
Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
|
|
|
|
ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
|
|
ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
|
|
ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
|
|
|
|
auto indicesType =
|
|
RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
|
|
auto indices =
|
|
DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
|
|
|
|
RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
|
|
auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
|
|
|
|
RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
|
|
auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
|
|
|
|
RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
|
|
auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
|
|
|
|
auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
|
|
auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
|
|
auto sparseString =
|
|
SparseElementsAttr::get(tensorString, indices, stringValue);
|
|
|
|
// Only index (0, 0) contains an element, others are supposed to return
|
|
// the zero/empty value.
|
|
auto zeroIntValue =
|
|
sparseInt.getValues<Attribute>()[{1, 1}].cast<IntegerAttr>();
|
|
EXPECT_EQ(zeroIntValue.getInt(), 0);
|
|
EXPECT_TRUE(zeroIntValue.getType() == intTy);
|
|
|
|
auto zeroFloatValue =
|
|
sparseFloat.getValues<Attribute>()[{1, 1}].cast<FloatAttr>();
|
|
EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
|
|
EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
|
|
|
|
auto zeroStringValue =
|
|
sparseString.getValues<Attribute>()[{1, 1}].cast<StringAttr>();
|
|
EXPECT_TRUE(zeroStringValue.getValue().empty());
|
|
EXPECT_TRUE(zeroStringValue.getType() == stringTy);
|
|
}
|
|
|
|
} // namespace
|