Introduce a new Dense Array attribute

This attribute is similar to DenseElementsAttr but does not support
splat. As such it has a much simpler API and does not need any smart
iterator: it exposes direct ArrayRef access.

A new syntax is introduced so that the generic printing/parsing looks
like:

  [:i64 1, -2, 3]

This attribute beings like an ArrayAttr but has a `:` token after the
opening square brace to introduce the element type (supported are I8,
I16, I32, I64, F32, F64) and the comma separated list for the data.

This is particularly convenient for attributes intended to be small,
like those referring to shapes.
For example a `transpose` operation with a `dims` attribute could be
defined as such:

  let arguments = (ins AnyTensor:$input, DenseI64ArrayAttr:$dims);
  let assemblyFormat = "$input `dims` `=` $dims attr-dict : type($input)";

And printed this way (the element type is elided in this case):

  transpose %input dims = [0, 2, 1] : tensor<2x3x4xf32>

The C++ API for dims would just directly return an ArrayRef<int64>

RFC: https://discourse.llvm.org/t/rfc-introduce-a-new-dense-array-attribute/63279

Recommit with a custom DenseArrayBaseAttrStorage class to ensure
over-alignment of the storage to the largest type.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123774
This commit is contained in:
Mehdi Amini 2022-06-28 11:29:27 +00:00
parent e2f313df8f
commit 7faf75bb3e
12 changed files with 618 additions and 13 deletions

View File

@ -66,8 +66,8 @@ template <typename T>
struct is_complex_t<std::complex<T>> : public std::true_type {};
} // namespace detail
/// An attribute that represents a reference to a dense vector or tensor object.
///
/// An attribute that represents a reference to a dense vector or tensor
/// object.
class DenseElementsAttr : public Attribute {
public:
using Attribute::Attribute;
@ -743,6 +743,55 @@ public:
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
/// Base class for DenseArrayAttr that is instantiated and specialized for each
/// supported element type below.
template <typename T>
class DenseArrayAttr : public DenseArrayBaseAttr {
public:
using DenseArrayBaseAttr::DenseArrayBaseAttr;
/// Implicit conversion to ArrayRef<T>.
operator ArrayRef<T>() const;
ArrayRef<T> asArrayRef() { return ArrayRef<T>{*this}; }
/// Builder from ArrayRef<T>.
static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
/// Print the short form `[42, 100, -1]` without any type prefix.
void print(AsmPrinter &printer) const;
void print(raw_ostream &os) const;
/// Print the short form `42, 100, -1` without any braces or type prefix.
void printWithoutBraces(raw_ostream &os) const;
/// Parse the short form `[42, 100, -1]` without any type prefix.
static Attribute parse(AsmParser &parser, Type odsType);
/// Parse the short form `42, 100, -1` without any type prefix or braces.
static Attribute parseWithoutBraces(AsmParser &parser, Type odsType);
/// Support for isa<>/cast<>.
static bool classof(Attribute attr);
};
template <>
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
extern template class DenseArrayAttr<int8_t>;
extern template class DenseArrayAttr<int16_t>;
extern template class DenseArrayAttr<int32_t>;
extern template class DenseArrayAttr<int64_t>;
extern template class DenseArrayAttr<float>;
extern template class DenseArrayAttr<double>;
} // namespace detail
// Public name for all the supported DenseArrayAttr
using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//

View File

@ -144,6 +144,78 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
// DenseIntOrFPElementsAttr
//===----------------------------------------------------------------------===//
def Builtin_DenseArrayBase : Builtin_Attr<
"DenseArrayBase", [ElementsAttrInterface]> {
let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
let description = [{
A dense array attribute is an attribute that represents a dense array of
primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
flat unidimensional array which does not have a storage optimization for
splat. This allows to expose the raw array through a C++ API as
`ArrayRef<T>`. This is the base class attribute, the actual access is
intended to be managed through the subclasses `DenseI8ArrayAttr`,
`DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
`DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
Syntax:
```
dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
```
Examples:
```mlir
[:i8]
[:i32 10, 42]
[:f64 42., 12.]
```
when a specific subclass is used as argument of an operation, the declarative
assembly will omit the type and print directly:
```
[1, 2, 3]
```
}];
let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
"DenseArrayBaseAttr::EltType":$eltType,
ArrayRefParameter<"char">:$elements);
let extraClassDeclaration = [{
// All possible supported element type.
enum class EltType { I8, I16, I32, I64, F32, F64 };
/// Allow implicit conversion to ElementsAttr.
operator ElementsAttr() const {
return *this ? cast<ElementsAttr>() : nullptr;
}
/// ElementsAttr implementation.
using ContiguousIterableTypesT =
std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
const float *value_begin_impl(OverloadToken<float>) const;
const double *value_begin_impl(OverloadToken<double>) const;
/// Methods to support type inquiry through isa, cast, and dyn_cast.
EltType getElementType() const;
/// Printer for the short form: will dispatch to the appropriate subclass.
void print(AsmPrinter &printer) const;
void print(raw_ostream &os) const;
/// Print the short form `42, 100, -1` without any braces or prefix.
void printWithoutBraces(raw_ostream &os) const;
}];
// Do not generate the storage class, we need to handle custom storage alignment.
let genStorageClass = 0;
let genAccessors = 0;
let skipDefaultBuilders = 1;
}
//===----------------------------------------------------------------------===//
// DenseIntOrFPElementsAttr
//===----------------------------------------------------------------------===//
def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
"DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
> {

View File

@ -1258,6 +1258,19 @@ class IntElementsAttrBase<Pred condition, string summary> :
let convertFromStorage = "$_self";
}
class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryName> :
ElementsAttrBase<CPred<"$_self.isa<::mlir::" # denseAttrName # ">()">,
summaryName # " dense array attribute"> {
let storageType = "::mlir::" # denseAttrName;
let returnType = "::llvm::ArrayRef<" # cppType # ">";
}
def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">;
def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">;
def IndexElementsAttr
: IntElementsAttrBase<CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>()
.getType()

View File

@ -1878,9 +1878,34 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
}
os << '>';
}
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
typeElision = AttrTypeElision::Must;
switch (denseArrayAttr.getElementType()) {
case DenseArrayBaseAttr::EltType::I8:
os << "[:i8 ";
break;
case DenseArrayBaseAttr::EltType::I16:
os << "[:i16 ";
break;
case DenseArrayBaseAttr::EltType::I32:
os << "[:i32 ";
break;
case DenseArrayBaseAttr::EltType::I64:
os << "[:i64 ";
break;
case DenseArrayBaseAttr::EltType::F32:
os << "[:f32 ";
break;
case DenseArrayBaseAttr::EltType::F64:
os << "[:f64 ";
break;
}
denseArrayAttr.printWithoutBraces(os);
os << "]";
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
printLocation(locAttr);
} else {
llvm::report_fatal_error("Unknown builtin attribute");
}
// Don't print the type if we must elide it, or if it is a None type.
if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {

View File

@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
@ -35,11 +36,11 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
void BuiltinDialect::registerAttributes() {
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
UnitAttr>();
addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
DenseIntOrFPElementsAttr, DenseStringElementsAttr,
DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
}
//===----------------------------------------------------------------------===//
@ -664,6 +665,274 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
readBits(getData(), offset + storageWidth, bitWidth)};
}
//===----------------------------------------------------------------------===//
// DenseArrayAttr
//===----------------------------------------------------------------------===//
/// Custom storage to ensure proper memory alignment for the allocation of
/// DenseArray of any element type.
struct ::mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
::llvm::ArrayRef<char>>;
DenseArrayBaseAttrStorage(ShapedType type,
DenseArrayBaseAttr::EltType eltType,
::llvm::ArrayRef<char> elements)
: AttributeStorage(type), eltType(eltType), elements(elements) {}
bool operator==(const KeyTy &tblgenKey) const {
return (getType() == std::get<0>(tblgenKey)) &&
(eltType == std::get<1>(tblgenKey)) &&
(elements == std::get<2>(tblgenKey));
}
static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
std::get<2>(tblgenKey));
}
static DenseArrayBaseAttrStorage *
construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
auto type = std::get<0>(tblgenKey);
auto eltType = std::get<1>(tblgenKey);
auto elements = std::get<2>(tblgenKey);
if (!elements.empty()) {
char *alloc = static_cast<char *>(
allocator.allocate(elements.size(), alignof(uint64_t)));
std::uninitialized_copy(elements.begin(), elements.end(), alloc);
elements = ArrayRef<char>(alloc, elements.size());
}
return new (allocator.allocate<DenseArrayBaseAttrStorage>())
DenseArrayBaseAttrStorage(type, eltType, elements);
}
DenseArrayBaseAttr::EltType eltType;
::llvm::ArrayRef<char> elements;
};
DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
return getImpl()->eltType;
}
const int8_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
}
const int16_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
return cast<DenseI16ArrayAttr>().asArrayRef().begin();
}
const int32_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
return cast<DenseI32ArrayAttr>().asArrayRef().begin();
}
const int64_t *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
return cast<DenseI64ArrayAttr>().asArrayRef().begin();
}
const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
return cast<DenseF32ArrayAttr>().asArrayRef().begin();
}
const double *
DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
return cast<DenseF64ArrayAttr>().asArrayRef().begin();
}
void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
print(printer.getStream());
}
void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
switch (getElementType()) {
case DenseArrayBaseAttr::EltType::I8:
this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
return;
case DenseArrayBaseAttr::EltType::I16:
this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
return;
case DenseArrayBaseAttr::EltType::I32:
this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
return;
case DenseArrayBaseAttr::EltType::I64:
this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
return;
case DenseArrayBaseAttr::EltType::F32:
this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
return;
case DenseArrayBaseAttr::EltType::F64:
this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
return;
}
llvm_unreachable("<unknown DenseArrayBaseAttr>");
}
void DenseArrayBaseAttr::print(raw_ostream &os) const {
os << "[";
printWithoutBraces(os);
os << "]";
}
template <typename T>
void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
print(printer.getStream());
}
template <typename T>
void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
ArrayRef<T> values{*this};
llvm::interleaveComma(values, os);
}
/// Specialization for int8_t for forcing printing as number instead of chars.
template <>
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
ArrayRef<int8_t> values{*this};
llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
}
template <typename T>
void DenseArrayAttr<T>::print(raw_ostream &os) const {
os << "[";
printWithoutBraces(os);
os << "]";
}
/// Parse a single element: generic template for int types, specialized for
/// floating points below.
template <typename T>
static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
return parser.parseInteger(value);
}
template <>
ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
double doubleVal;
if (parser.parseFloat(doubleVal))
return failure();
value = doubleVal;
return success();
}
template <>
ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
return parser.parseFloat(value);
}
/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
template <typename T>
Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
Type odsType) {
SmallVector<T> data;
if (failed(parser.parseCommaSeparatedList([&]() {
T value;
if (parseDenseArrayAttrElt(parser, value))
return failure();
data.push_back(value);
return success();
})))
return {};
return get(parser.getContext(), data);
}
/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
template <typename T>
Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
if (parser.parseLSquare())
return {};
Attribute result = parseWithoutBraces(parser, odsType);
if (parser.parseRSquare())
return {};
return result;
}
/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
template <typename T>
DenseArrayAttr<T>::operator ArrayRef<T>() const {
ArrayRef<char> raw = getImpl()->elements;
assert((raw.size() % sizeof(T)) == 0);
return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
raw.size() / sizeof(T));
}
namespace {
/// Mapping from C++ element type to MLIR DenseArrayAttr internals.
template <typename T>
struct denseArrayAttrEltTypeBuilder;
template <>
struct denseArrayAttrEltTypeBuilder<int8_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
return VectorType::get(shape, IntegerType::get(context, 8));
}
};
template <>
struct denseArrayAttrEltTypeBuilder<int16_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
return VectorType::get(shape, IntegerType::get(context, 16));
}
};
template <>
struct denseArrayAttrEltTypeBuilder<int32_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
return VectorType::get(shape, IntegerType::get(context, 32));
}
};
template <>
struct denseArrayAttrEltTypeBuilder<int64_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
return VectorType::get(shape, IntegerType::get(context, 64));
}
};
template <>
struct denseArrayAttrEltTypeBuilder<float> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
return VectorType::get(shape, Float32Type::get(context));
}
};
template <>
struct denseArrayAttrEltTypeBuilder<double> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
return VectorType::get(shape, Float64Type::get(context));
}
};
} // namespace
/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
template <typename T>
DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
ArrayRef<T> content) {
auto shapedType =
denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
content.size() * sizeof(T));
return Base::get(context, shapedType, eltType, rawArray)
.template cast<DenseArrayAttr<T>>();
}
template <typename T>
bool DenseArrayAttr<T>::classof(Attribute attr) {
return attr.isa<DenseArrayBaseAttr>() &&
attr.cast<DenseArrayBaseAttr>().getElementType() ==
denseArrayAttrEltTypeBuilder<T>::eltType;
}
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseArrayAttr.
template class DenseArrayAttr<int8_t>;
template class DenseArrayAttr<int16_t>;
template class DenseArrayAttr<int32_t>;
template class DenseArrayAttr<int64_t>;
template class DenseArrayAttr<float>;
template class DenseArrayAttr<double>;
} // namespace detail
} // namespace mlir
//===----------------------------------------------------------------------===//
// DenseElementsAttr
//===----------------------------------------------------------------------===//

View File

@ -11,9 +11,12 @@
//===----------------------------------------------------------------------===//
#include "Parser.h"
#include "AsmParserImpl.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Parser/AsmParserState.h"
#include "llvm/ADT/StringExtras.h"
@ -30,6 +33,7 @@ using namespace mlir::detail;
/// | float-literal (`:` float-type)?
/// | string-literal (`:` type)?
/// | type
/// | `[` `:` (integer-type | float-type) tensor-literal `]`
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
/// | symbol-ref-id (`::` symbol-ref-id)*
@ -67,13 +71,16 @@ Attribute Parser::parseAttribute(Type type) {
// Parse an array attribute.
case Token::l_square: {
consumeToken(Token::l_square);
if (consumeIf(Token::colon))
return parseDenseArrayAttr();
SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute());
return elements.back() ? success() : failure();
};
if (parseCommaSeparatedList(Delimiter::Square, parseElt))
if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
@ -812,6 +819,66 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
// ElementsAttr Parser
//===----------------------------------------------------------------------===//
namespace {
/// This class provides an implementation of AsmParser, allowing to call back
/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
/// in libMLIRIR.
class CustomAsmParser : public AsmParserImpl<AsmParser> {
public:
CustomAsmParser(Parser &parser)
: AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
};
} // namespace
/// Parse a dense array attribute.
Attribute Parser::parseDenseArrayAttr() {
auto typeLoc = getToken().getLoc();
auto type = parseType();
if (!type)
return {};
CustomAsmParser parser(*this);
Attribute result;
if (auto intType = type.dyn_cast<IntegerType>()) {
switch (type.getIntOrFloatBitWidth()) {
case 8:
result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 16:
result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 32:
result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 64:
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
break;
default:
emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
return {};
}
} else if (auto floatType = type.dyn_cast<FloatType>()) {
switch (type.getIntOrFloatBitWidth()) {
case 32:
result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
break;
case 64:
result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
break;
default:
emitError(typeLoc, "expected f32 or f64 but got: ") << type;
return {};
}
} else {
emitError(typeLoc, "expected integer or float type, got: ") << type;
return {};
}
if (!consumeIf(Token::r_square)) {
emitError("expected ']' to close an array attribute");
return {};
}
return result;
}
/// Parse a dense elements attribute.
Attribute Parser::parseDenseElementsAttr(Type attrType) {
auto attribLoc = getToken().getLoc();

View File

@ -264,6 +264,9 @@ public:
Attribute parseDenseElementsAttr(Type attrType);
ShapedType parseElementsLiteralType(Type type);
/// Parse a DenseArrayAttr.
Attribute parseDenseArrayAttr();
/// Parse a sparse elements attribute.
Attribute parseSparseElementsAttr(Type attrType);

View File

@ -513,6 +513,45 @@ func.func @simple_scalar_example() {
return
}
// -----
//===----------------------------------------------------------------------===//
// Test DenseArrayAttr
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @dense_array_attr
func.func @dense_array_attr() attributes{
// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
f32attr = [:f32 1024., 453., -6435.],
// CHECK-SAME: f64attr = [:f64 -1.420000e+02],
f64attr = [:f64 -142.],
// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
i16attr = [:i16 3, 5, -4, 10],
// CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
i32attr = [:i32 1024, 453, -6435],
// CHECK-SAME: i64attr = [:i64 -142],
i64attr = [:i64 -142],
// CHECK-SAME: i8attr = [:i8 1, -2, 3]
i8attr = [:i8 1, -2, 3]
} {
// CHECK: test.dense_array_attr
test.dense_array_attr
// CHECK-SAME: i8attr = [1, -2, 3]
i8attr = [1, -2, 3]
// CHECK-SAME: i16attr = [3, 5, -4, 10]
i16attr = [3, 5, -4, 10]
// CHECK-SAME: i32attr = [1024, 453, -6435]
i32attr = [1024, 453, -6435]
// CHECK-SAME: i64attr = [-142]
i64attr = [-142]
// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03]
f32attr = [1024., 453., -6435.]
// CHECK-SAME: f64attr = [-1.420000e+02]
f64attr = [-142.]
return
}
// -----
//===----------------------------------------------------------------------===//

View File

@ -5,23 +5,40 @@
// This tests that the abstract iteration of ElementsAttr works properly, and
// is properly failable when necessary.
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64>
// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}}
// expected-error@below {{Test iterating `APInt`: unable to iterate type}}
// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}}
arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64>
// Check that we don't crash on empty element attributes.
// expected-error@below {{Test iterating `int64_t`: }}
// expected-error@below {{Test iterating `uint64_t`: }}
// expected-error@below {{Test iterating `APInt`: }}
// expected-error@below {{Test iterating `IntegerAttr`: }}
arith.constant dense<> : tensor<0xi64>
// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
arith.constant [:i8 10, 11, -12, 13, 14]
// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
arith.constant [:i16 10, 11, -12, 13, 14]
// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
arith.constant [:i32 10, 11, -12, 13, 14]
// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
arith.constant [:i64 10, 11, -12, 13, 14]
// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
arith.constant [:f32 10., 11., -12., 13., 14.]
// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
arith.constant [:f64 10., 11., -12., 13., 14.]

View File

@ -1654,7 +1654,7 @@ func.func @foo() {} // expected-error {{expected non-empty function body}}
// -----
// expected-error@+1 {{expected ']'}}
// expected-error@+1 {{expected ',' or ']'}}
"f"() { b = [@m:
// -----

View File

@ -270,6 +270,22 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
);
}
def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
let arguments = (ins
DenseI8ArrayAttr:$i8attr,
DenseI16ArrayAttr:$i16attr,
DenseI32ArrayAttr:$i32attr,
DenseI64ArrayAttr:$i64attr,
DenseF32ArrayAttr:$f32attr,
DenseF64ArrayAttr:$f64attr
);
let assemblyFormat = [{
`i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
`i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
attr-dict
}];
}
//===----------------------------------------------------------------------===//
// Test Enum Attributes
//===----------------------------------------------------------------------===//

View File

@ -14,6 +14,17 @@
using namespace mlir;
using namespace test;
// Helper to print one scalar value, force int8_t to print as integer instead of
// char.
template <typename T>
static void printOneElement(InFlightDiagnostic &os, T value) {
os << llvm::formatv("{0}", value).str();
}
template <>
void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
}
namespace {
struct TestElementsAttrInterface
: public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
@ -29,6 +40,31 @@ struct TestElementsAttrInterface
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
if (!elementsAttr)
continue;
if (auto concreteAttr =
attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
switch (concreteAttr.getElementType()) {
case DenseArrayBaseAttr::EltType::I8:
testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
break;
case DenseArrayBaseAttr::EltType::I16:
testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
break;
case DenseArrayBaseAttr::EltType::I32:
testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
break;
case DenseArrayBaseAttr::EltType::I64:
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
break;
case DenseArrayBaseAttr::EltType::F32:
testElementsAttrIteration<float>(op, elementsAttr, "float");
break;
case DenseArrayBaseAttr::EltType::F64:
testElementsAttrIteration<double>(op, elementsAttr, "double");
break;
}
continue;
}
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
@ -48,9 +84,8 @@ struct TestElementsAttrInterface
return;
}
llvm::interleaveComma(*values, diag, [&](T value) {
diag << llvm::formatv("{0}", value).str();
});
llvm::interleaveComma(*values, diag,
[&](T value) { printOneElement(diag, value); });
}
};
} // namespace