2019-06-12 07:14:17 +08:00
|
|
|
//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
|
|
|
|
//
|
2020-01-26 11:58:30 +08:00
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
2019-12-24 01:35:36 +08:00
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2019-06-12 07:14:17 +08:00
|
|
|
//
|
2019-12-24 01:35:36 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-06-12 07:14:17 +08:00
|
|
|
|
|
|
|
#include "mlir/IR/Attributes.h"
|
2020-04-27 04:52:50 +08:00
|
|
|
#include "mlir/IR/Identifier.h"
|
2019-06-12 07:14:17 +08:00
|
|
|
#include "mlir/IR/StandardTypes.h"
|
|
|
|
#include "gtest/gtest.h"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
using namespace mlir::detail;
|
|
|
|
|
|
|
|
template <typename EltTy>
|
|
|
|
static void testSplat(Type eltType, const EltTy &splatElt) {
|
2020-04-27 04:52:50 +08:00
|
|
|
RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
|
2019-06-12 07:14:17 +08:00
|
|
|
|
|
|
|
// Check that the generated splat is the same for 1 element and N elements.
|
|
|
|
DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
|
|
|
|
EXPECT_TRUE(splat.isSplat());
|
|
|
|
|
|
|
|
auto detectedSplat =
|
|
|
|
DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
|
|
|
|
EXPECT_EQ(detectedSplat, splat);
|
2020-05-06 03:39:12 +08:00
|
|
|
|
|
|
|
for (auto newValue : detectedSplat.template getValues<EltTy>())
|
2020-05-06 03:39:22 +08:00
|
|
|
EXPECT_TRUE(newValue == splatElt);
|
2019-06-12 07:14:17 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
TEST(DenseSplatTest, BoolSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
IntegerType boolTy = IntegerType::get(1, &context);
|
2020-04-27 04:52:50 +08:00
|
|
|
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
|
2019-06-12 07:14:17 +08:00
|
|
|
|
|
|
|
// Check that splat is automatically detected for boolean values.
|
|
|
|
/// True.
|
2019-06-14 23:41:32 +08:00
|
|
|
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
|
2019-06-12 07:14:17 +08:00
|
|
|
EXPECT_TRUE(trueSplat.isSplat());
|
|
|
|
/// False.
|
2019-06-14 23:41:32 +08:00
|
|
|
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
|
2019-06-12 07:14:17 +08:00
|
|
|
EXPECT_TRUE(falseSplat.isSplat());
|
|
|
|
EXPECT_NE(falseSplat, trueSplat);
|
|
|
|
|
|
|
|
/// Detect and handle splat within 8 elements (bool values are bit-packed).
|
|
|
|
/// True.
|
2019-06-14 23:41:32 +08:00
|
|
|
auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
|
2019-06-12 07:14:17 +08:00
|
|
|
EXPECT_EQ(detectedSplat, trueSplat);
|
|
|
|
/// False.
|
2019-06-14 23:41:32 +08:00
|
|
|
detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
|
2019-06-12 07:14:17 +08:00
|
|
|
EXPECT_EQ(detectedSplat, falseSplat);
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseSplatTest, LargeBoolSplat) {
|
2019-06-14 07:09:56 +08:00
|
|
|
constexpr int64_t boolCount = 56;
|
2019-06-12 07:14:17 +08:00
|
|
|
|
|
|
|
MLIRContext context;
|
|
|
|
IntegerType boolTy = IntegerType::get(1, &context);
|
2020-04-27 04:52:50 +08:00
|
|
|
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
|
2019-06-12 07:14:17 +08:00
|
|
|
|
|
|
|
// Check that splat is automatically detected for boolean values.
|
|
|
|
/// True.
|
|
|
|
DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
|
|
|
|
DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
|
|
|
|
EXPECT_TRUE(trueSplat.isSplat());
|
|
|
|
EXPECT_TRUE(falseSplat.isSplat());
|
|
|
|
|
|
|
|
/// Detect that the large boolean arrays are properly splatted.
|
|
|
|
/// True.
|
|
|
|
SmallVector<bool, 64> trueValues(boolCount, true);
|
|
|
|
auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
|
|
|
|
EXPECT_EQ(detectedSplat, trueSplat);
|
|
|
|
/// False.
|
|
|
|
SmallVector<bool, 64> falseValues(boolCount, false);
|
|
|
|
detectedSplat = DenseElementsAttr::get(shape, falseValues);
|
|
|
|
EXPECT_EQ(detectedSplat, falseSplat);
|
|
|
|
}
|
|
|
|
|
2019-06-18 10:46:31 +08:00
|
|
|
TEST(DenseSplatTest, BoolNonSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
IntegerType boolTy = IntegerType::get(1, &context);
|
2020-04-27 04:52:50 +08:00
|
|
|
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
|
2019-06-18 10:46:31 +08:00
|
|
|
|
|
|
|
// Check that we properly handle non-splat values.
|
|
|
|
DenseElementsAttr nonSplat =
|
|
|
|
DenseElementsAttr::get(shape, {false, false, true, false, false, true});
|
|
|
|
EXPECT_FALSE(nonSplat.isSplat());
|
|
|
|
}
|
|
|
|
|
2019-06-12 07:14:17 +08:00
|
|
|
TEST(DenseSplatTest, OddIntSplat) {
|
|
|
|
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
|
|
|
|
MLIRContext context;
|
|
|
|
constexpr size_t intWidth = 19;
|
|
|
|
IntegerType intTy = IntegerType::get(intWidth, &context);
|
|
|
|
APInt value(intWidth, 10);
|
|
|
|
|
|
|
|
testSplat(intTy, value);
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseSplatTest, Int32Splat) {
|
|
|
|
MLIRContext context;
|
|
|
|
IntegerType intTy = IntegerType::get(32, &context);
|
|
|
|
int value = 64;
|
|
|
|
|
|
|
|
testSplat(intTy, value);
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseSplatTest, IntAttrSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
IntegerType intTy = IntegerType::get(85, &context);
|
|
|
|
Attribute value = IntegerAttr::get(intTy, 109);
|
|
|
|
|
|
|
|
testSplat(intTy, value);
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseSplatTest, F32Splat) {
|
|
|
|
MLIRContext context;
|
|
|
|
FloatType floatTy = FloatType::getF32(&context);
|
|
|
|
float value = 10.0;
|
|
|
|
|
|
|
|
testSplat(floatTy, value);
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseSplatTest, F64Splat) {
|
|
|
|
MLIRContext context;
|
|
|
|
FloatType floatTy = FloatType::getF64(&context);
|
|
|
|
double value = 10.0;
|
|
|
|
|
|
|
|
testSplat(floatTy, APFloat(value));
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseSplatTest, FloatAttrSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
FloatType floatTy = FloatType::getBF16(&context);
|
|
|
|
Attribute value = FloatAttr::get(floatTy, 10.0);
|
|
|
|
|
|
|
|
testSplat(floatTy, value);
|
|
|
|
}
|
2020-01-10 06:41:49 +08:00
|
|
|
|
|
|
|
TEST(DenseSplatTest, BF16Splat) {
|
|
|
|
MLIRContext context;
|
|
|
|
FloatType floatTy = FloatType::getBF16(&context);
|
|
|
|
// Note: We currently use double to represent bfloat16.
|
|
|
|
double value = 10.0;
|
|
|
|
|
|
|
|
testSplat(floatTy, value);
|
|
|
|
}
|
|
|
|
|
2020-04-27 04:52:50 +08:00
|
|
|
TEST(DenseSplatTest, StringSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
Type stringType =
|
|
|
|
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
|
|
|
StringRef value = "test-string";
|
|
|
|
testSplat(stringType, value);
|
|
|
|
}
|
|
|
|
|
2020-05-02 07:26:45 +08:00
|
|
|
TEST(DenseSplatTest, StringAttrSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
Type stringType =
|
|
|
|
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
|
|
|
Attribute stringAttr = StringAttr::get("test-string", stringType);
|
|
|
|
testSplat(stringType, stringAttr);
|
|
|
|
}
|
|
|
|
|
2020-05-06 03:39:12 +08:00
|
|
|
TEST(DenseComplexTest, ComplexFloatSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
|
|
|
|
std::complex<float> value(10.0, 15.0);
|
|
|
|
testSplat(complexType, value);
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseComplexTest, ComplexIntSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
|
|
|
std::complex<int64_t> value(10, 15);
|
|
|
|
testSplat(complexType, value);
|
|
|
|
}
|
|
|
|
|
2020-05-06 03:39:22 +08:00
|
|
|
TEST(DenseComplexTest, ComplexAPFloatSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
|
|
|
|
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
|
|
|
|
testSplat(complexType, value);
|
|
|
|
}
|
|
|
|
|
|
|
|
TEST(DenseComplexTest, ComplexAPIntSplat) {
|
|
|
|
MLIRContext context;
|
|
|
|
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
|
|
|
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
|
|
|
|
testSplat(complexType, value);
|
|
|
|
}
|
|
|
|
|
2019-06-12 07:14:17 +08:00
|
|
|
} // end namespace
|