forked from OSchip/llvm-project
Introduce a new Dense Array attribute
This attribute is similar to DenseElementsAttr but does not support splat. As such it has a much simpler API and does not need any smart iterator: it exposes direct ArrayRef access. A new syntax is introduced so that the generic printing/parsing looks like: [:i64 1, -2, 3] This attribute beings like an ArrayAttr but has a `:` token after the opening square brace to introduce the element type (supported are I8, I16, I32, I64, F32, F64) and the comma separated list for the data. This is particularly convenient for attributes intended to be small, like those referring to shapes. For example a `transpose` operation with a `dims` attribute could be defined as such: let arguments = (ins AnyTensor:$input, DenseI64ArrayAttr:$dims); let assemblyFormat = "$input `dims` `=` $dims attr-dict : type($input)"; And printed this way (the element type is elided in this case): transpose %input dims = [0, 2, 1] : tensor<2x3x4xf32> The C++ API for dims would just directly return an ArrayRef<int64> RFC: https://discourse.llvm.org/t/rfc-introduce-a-new-dense-array-attribute/63279 Recommit with a custom DenseArrayBaseAttrStorage class to ensure over-alignment of the storage to the largest type. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D123774
This commit is contained in:
parent
e2f313df8f
commit
7faf75bb3e
|
@ -66,8 +66,8 @@ template <typename T>
|
|||
struct is_complex_t<std::complex<T>> : public std::true_type {};
|
||||
} // namespace detail
|
||||
|
||||
/// An attribute that represents a reference to a dense vector or tensor object.
|
||||
///
|
||||
/// An attribute that represents a reference to a dense vector or tensor
|
||||
/// object.
|
||||
class DenseElementsAttr : public Attribute {
|
||||
public:
|
||||
using Attribute::Attribute;
|
||||
|
@ -743,6 +743,55 @@ public:
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
/// Base class for DenseArrayAttr that is instantiated and specialized for each
|
||||
/// supported element type below.
|
||||
template <typename T>
|
||||
class DenseArrayAttr : public DenseArrayBaseAttr {
|
||||
public:
|
||||
using DenseArrayBaseAttr::DenseArrayBaseAttr;
|
||||
|
||||
/// Implicit conversion to ArrayRef<T>.
|
||||
operator ArrayRef<T>() const;
|
||||
ArrayRef<T> asArrayRef() { return ArrayRef<T>{*this}; }
|
||||
|
||||
/// Builder from ArrayRef<T>.
|
||||
static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
|
||||
|
||||
/// Print the short form `[42, 100, -1]` without any type prefix.
|
||||
void print(AsmPrinter &printer) const;
|
||||
void print(raw_ostream &os) const;
|
||||
/// Print the short form `42, 100, -1` without any braces or type prefix.
|
||||
void printWithoutBraces(raw_ostream &os) const;
|
||||
|
||||
/// Parse the short form `[42, 100, -1]` without any type prefix.
|
||||
static Attribute parse(AsmParser &parser, Type odsType);
|
||||
|
||||
/// Parse the short form `42, 100, -1` without any type prefix or braces.
|
||||
static Attribute parseWithoutBraces(AsmParser &parser, Type odsType);
|
||||
|
||||
/// Support for isa<>/cast<>.
|
||||
static bool classof(Attribute attr);
|
||||
};
|
||||
template <>
|
||||
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
|
||||
|
||||
extern template class DenseArrayAttr<int8_t>;
|
||||
extern template class DenseArrayAttr<int16_t>;
|
||||
extern template class DenseArrayAttr<int32_t>;
|
||||
extern template class DenseArrayAttr<int64_t>;
|
||||
extern template class DenseArrayAttr<float>;
|
||||
extern template class DenseArrayAttr<double>;
|
||||
} // namespace detail
|
||||
|
||||
// Public name for all the supported DenseArrayAttr
|
||||
using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
|
||||
using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
|
||||
using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
|
||||
using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
|
||||
using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
|
||||
using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BoolAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -144,6 +144,78 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
|
|||
// DenseIntOrFPElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Builtin_DenseArrayBase : Builtin_Attr<
|
||||
"DenseArrayBase", [ElementsAttrInterface]> {
|
||||
let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
|
||||
let description = [{
|
||||
A dense array attribute is an attribute that represents a dense array of
|
||||
primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
|
||||
flat unidimensional array which does not have a storage optimization for
|
||||
splat. This allows to expose the raw array through a C++ API as
|
||||
`ArrayRef<T>`. This is the base class attribute, the actual access is
|
||||
intended to be managed through the subclasses `DenseI8ArrayAttr`,
|
||||
`DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
|
||||
`DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
|
||||
|
||||
Syntax:
|
||||
|
||||
```
|
||||
dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
|
||||
```
|
||||
Examples:
|
||||
|
||||
```mlir
|
||||
[:i8]
|
||||
[:i32 10, 42]
|
||||
[:f64 42., 12.]
|
||||
```
|
||||
|
||||
when a specific subclass is used as argument of an operation, the declarative
|
||||
assembly will omit the type and print directly:
|
||||
```
|
||||
[1, 2, 3]
|
||||
```
|
||||
}];
|
||||
let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
|
||||
"DenseArrayBaseAttr::EltType":$eltType,
|
||||
ArrayRefParameter<"char">:$elements);
|
||||
let extraClassDeclaration = [{
|
||||
// All possible supported element type.
|
||||
enum class EltType { I8, I16, I32, I64, F32, F64 };
|
||||
|
||||
/// Allow implicit conversion to ElementsAttr.
|
||||
operator ElementsAttr() const {
|
||||
return *this ? cast<ElementsAttr>() : nullptr;
|
||||
}
|
||||
|
||||
/// ElementsAttr implementation.
|
||||
using ContiguousIterableTypesT =
|
||||
std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
|
||||
const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
|
||||
const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
|
||||
const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
|
||||
const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
|
||||
const float *value_begin_impl(OverloadToken<float>) const;
|
||||
const double *value_begin_impl(OverloadToken<double>) const;
|
||||
|
||||
/// Methods to support type inquiry through isa, cast, and dyn_cast.
|
||||
EltType getElementType() const;
|
||||
/// Printer for the short form: will dispatch to the appropriate subclass.
|
||||
void print(AsmPrinter &printer) const;
|
||||
void print(raw_ostream &os) const;
|
||||
/// Print the short form `42, 100, -1` without any braces or prefix.
|
||||
void printWithoutBraces(raw_ostream &os) const;
|
||||
}];
|
||||
// Do not generate the storage class, we need to handle custom storage alignment.
|
||||
let genStorageClass = 0;
|
||||
let genAccessors = 0;
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseIntOrFPElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
|
||||
"DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
|
||||
> {
|
||||
|
|
|
@ -1258,6 +1258,19 @@ class IntElementsAttrBase<Pred condition, string summary> :
|
|||
let convertFromStorage = "$_self";
|
||||
}
|
||||
|
||||
class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryName> :
|
||||
ElementsAttrBase<CPred<"$_self.isa<::mlir::" # denseAttrName # ">()">,
|
||||
summaryName # " dense array attribute"> {
|
||||
let storageType = "::mlir::" # denseAttrName;
|
||||
let returnType = "::llvm::ArrayRef<" # cppType # ">";
|
||||
}
|
||||
def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
|
||||
def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
|
||||
def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
|
||||
def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
|
||||
def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">;
|
||||
def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">;
|
||||
|
||||
def IndexElementsAttr
|
||||
: IntElementsAttrBase<CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>()
|
||||
.getType()
|
||||
|
|
|
@ -1878,9 +1878,34 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
|||
}
|
||||
os << '>';
|
||||
}
|
||||
|
||||
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
|
||||
typeElision = AttrTypeElision::Must;
|
||||
switch (denseArrayAttr.getElementType()) {
|
||||
case DenseArrayBaseAttr::EltType::I8:
|
||||
os << "[:i8 ";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I16:
|
||||
os << "[:i16 ";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I32:
|
||||
os << "[:i32 ";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I64:
|
||||
os << "[:i64 ";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::F32:
|
||||
os << "[:f32 ";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::F64:
|
||||
os << "[:f64 ";
|
||||
break;
|
||||
}
|
||||
denseArrayAttr.printWithoutBraces(os);
|
||||
os << "]";
|
||||
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
|
||||
printLocation(locAttr);
|
||||
} else {
|
||||
llvm::report_fatal_error("Unknown builtin attribute");
|
||||
}
|
||||
// Don't print the type if we must elide it, or if it is a None type.
|
||||
if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
@ -35,11 +36,11 @@ using namespace mlir::detail;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void BuiltinDialect::registerAttributes() {
|
||||
addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
|
||||
DenseStringElementsAttr, DictionaryAttr, FloatAttr,
|
||||
SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
|
||||
OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
|
||||
UnitAttr>();
|
||||
addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
|
||||
DenseIntOrFPElementsAttr, DenseStringElementsAttr,
|
||||
DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
|
||||
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
|
||||
SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -664,6 +665,274 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
|
|||
readBits(getData(), offset + storageWidth, bitWidth)};
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseArrayAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Custom storage to ensure proper memory alignment for the allocation of
|
||||
/// DenseArray of any element type.
|
||||
struct ::mlir::detail::DenseArrayBaseAttrStorage : public AttributeStorage {
|
||||
using KeyTy = std::tuple<ShapedType, DenseArrayBaseAttr::EltType,
|
||||
::llvm::ArrayRef<char>>;
|
||||
DenseArrayBaseAttrStorage(ShapedType type,
|
||||
DenseArrayBaseAttr::EltType eltType,
|
||||
::llvm::ArrayRef<char> elements)
|
||||
: AttributeStorage(type), eltType(eltType), elements(elements) {}
|
||||
|
||||
bool operator==(const KeyTy &tblgenKey) const {
|
||||
return (getType() == std::get<0>(tblgenKey)) &&
|
||||
(eltType == std::get<1>(tblgenKey)) &&
|
||||
(elements == std::get<2>(tblgenKey));
|
||||
}
|
||||
|
||||
static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
|
||||
return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey),
|
||||
std::get<2>(tblgenKey));
|
||||
}
|
||||
|
||||
static DenseArrayBaseAttrStorage *
|
||||
construct(AttributeStorageAllocator &allocator, const KeyTy &tblgenKey) {
|
||||
auto type = std::get<0>(tblgenKey);
|
||||
auto eltType = std::get<1>(tblgenKey);
|
||||
auto elements = std::get<2>(tblgenKey);
|
||||
if (!elements.empty()) {
|
||||
char *alloc = static_cast<char *>(
|
||||
allocator.allocate(elements.size(), alignof(uint64_t)));
|
||||
std::uninitialized_copy(elements.begin(), elements.end(), alloc);
|
||||
elements = ArrayRef<char>(alloc, elements.size());
|
||||
}
|
||||
return new (allocator.allocate<DenseArrayBaseAttrStorage>())
|
||||
DenseArrayBaseAttrStorage(type, eltType, elements);
|
||||
}
|
||||
|
||||
DenseArrayBaseAttr::EltType eltType;
|
||||
::llvm::ArrayRef<char> elements;
|
||||
};
|
||||
|
||||
DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
|
||||
return getImpl()->eltType;
|
||||
}
|
||||
|
||||
const int8_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
|
||||
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int16_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
|
||||
return cast<DenseI16ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int32_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
|
||||
return cast<DenseI32ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int64_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
|
||||
return cast<DenseI64ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
|
||||
return cast<DenseF32ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const double *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
|
||||
return cast<DenseF64ArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
|
||||
void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
|
||||
print(printer.getStream());
|
||||
}
|
||||
|
||||
void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
|
||||
switch (getElementType()) {
|
||||
case DenseArrayBaseAttr::EltType::I8:
|
||||
this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
case DenseArrayBaseAttr::EltType::I16:
|
||||
this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
case DenseArrayBaseAttr::EltType::I32:
|
||||
this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
case DenseArrayBaseAttr::EltType::I64:
|
||||
this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
case DenseArrayBaseAttr::EltType::F32:
|
||||
this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
case DenseArrayBaseAttr::EltType::F64:
|
||||
this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
}
|
||||
llvm_unreachable("<unknown DenseArrayBaseAttr>");
|
||||
}
|
||||
|
||||
void DenseArrayBaseAttr::print(raw_ostream &os) const {
|
||||
os << "[";
|
||||
printWithoutBraces(os);
|
||||
os << "]";
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
|
||||
print(printer.getStream());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
|
||||
ArrayRef<T> values{*this};
|
||||
llvm::interleaveComma(values, os);
|
||||
}
|
||||
|
||||
/// Specialization for int8_t for forcing printing as number instead of chars.
|
||||
template <>
|
||||
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
|
||||
ArrayRef<int8_t> values{*this};
|
||||
llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DenseArrayAttr<T>::print(raw_ostream &os) const {
|
||||
os << "[";
|
||||
printWithoutBraces(os);
|
||||
os << "]";
|
||||
}
|
||||
|
||||
/// Parse a single element: generic template for int types, specialized for
|
||||
/// floating points below.
|
||||
template <typename T>
|
||||
static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
|
||||
return parser.parseInteger(value);
|
||||
}
|
||||
|
||||
template <>
|
||||
ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
|
||||
double doubleVal;
|
||||
if (parser.parseFloat(doubleVal))
|
||||
return failure();
|
||||
value = doubleVal;
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
|
||||
return parser.parseFloat(value);
|
||||
}
|
||||
|
||||
/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
|
||||
template <typename T>
|
||||
Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
|
||||
Type odsType) {
|
||||
SmallVector<T> data;
|
||||
if (failed(parser.parseCommaSeparatedList([&]() {
|
||||
T value;
|
||||
if (parseDenseArrayAttrElt(parser, value))
|
||||
return failure();
|
||||
data.push_back(value);
|
||||
return success();
|
||||
})))
|
||||
return {};
|
||||
return get(parser.getContext(), data);
|
||||
}
|
||||
|
||||
/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
|
||||
template <typename T>
|
||||
Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
|
||||
if (parser.parseLSquare())
|
||||
return {};
|
||||
Attribute result = parseWithoutBraces(parser, odsType);
|
||||
if (parser.parseRSquare())
|
||||
return {};
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
|
||||
template <typename T>
|
||||
DenseArrayAttr<T>::operator ArrayRef<T>() const {
|
||||
ArrayRef<char> raw = getImpl()->elements;
|
||||
assert((raw.size() % sizeof(T)) == 0);
|
||||
return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
|
||||
raw.size() / sizeof(T));
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Mapping from C++ element type to MLIR DenseArrayAttr internals.
|
||||
template <typename T>
|
||||
struct denseArrayAttrEltTypeBuilder;
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<int8_t> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
|
||||
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
|
||||
return VectorType::get(shape, IntegerType::get(context, 8));
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<int16_t> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
|
||||
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
|
||||
return VectorType::get(shape, IntegerType::get(context, 16));
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<int32_t> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
|
||||
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
|
||||
return VectorType::get(shape, IntegerType::get(context, 32));
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<int64_t> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
|
||||
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
|
||||
return VectorType::get(shape, IntegerType::get(context, 64));
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<float> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
|
||||
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
|
||||
return VectorType::get(shape, Float32Type::get(context));
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<double> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
|
||||
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
|
||||
return VectorType::get(shape, Float64Type::get(context));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
|
||||
template <typename T>
|
||||
DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
|
||||
ArrayRef<T> content) {
|
||||
auto shapedType =
|
||||
denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
|
||||
auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
|
||||
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
|
||||
content.size() * sizeof(T));
|
||||
return Base::get(context, shapedType, eltType, rawArray)
|
||||
.template cast<DenseArrayAttr<T>>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DenseArrayAttr<T>::classof(Attribute attr) {
|
||||
return attr.isa<DenseArrayBaseAttr>() &&
|
||||
attr.cast<DenseArrayBaseAttr>().getElementType() ==
|
||||
denseArrayAttrEltTypeBuilder<T>::eltType;
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
// Explicit instantiation for all the supported DenseArrayAttr.
|
||||
template class DenseArrayAttr<int8_t>;
|
||||
template class DenseArrayAttr<int16_t>;
|
||||
template class DenseArrayAttr<int32_t>;
|
||||
template class DenseArrayAttr<int64_t>;
|
||||
template class DenseArrayAttr<float>;
|
||||
template class DenseArrayAttr<double>;
|
||||
} // namespace detail
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -11,9 +11,12 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "Parser.h"
|
||||
|
||||
#include "AsmParserImpl.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/Parser/AsmParserState.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -30,6 +33,7 @@ using namespace mlir::detail;
|
|||
/// | float-literal (`:` float-type)?
|
||||
/// | string-literal (`:` type)?
|
||||
/// | type
|
||||
/// | `[` `:` (integer-type | float-type) tensor-literal `]`
|
||||
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
|
||||
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
|
||||
/// | symbol-ref-id (`::` symbol-ref-id)*
|
||||
|
@ -67,13 +71,16 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
|
||||
// Parse an array attribute.
|
||||
case Token::l_square: {
|
||||
consumeToken(Token::l_square);
|
||||
if (consumeIf(Token::colon))
|
||||
return parseDenseArrayAttr();
|
||||
SmallVector<Attribute, 4> elements;
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
elements.push_back(parseAttribute());
|
||||
return elements.back() ? success() : failure();
|
||||
};
|
||||
|
||||
if (parseCommaSeparatedList(Delimiter::Square, parseElt))
|
||||
if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
|
||||
return nullptr;
|
||||
return builder.getArrayAttr(elements);
|
||||
}
|
||||
|
@ -812,6 +819,66 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
|
|||
// ElementsAttr Parser
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class provides an implementation of AsmParser, allowing to call back
|
||||
/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
|
||||
/// in libMLIRIR.
|
||||
class CustomAsmParser : public AsmParserImpl<AsmParser> {
|
||||
public:
|
||||
CustomAsmParser(Parser &parser)
|
||||
: AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Parse a dense array attribute.
|
||||
Attribute Parser::parseDenseArrayAttr() {
|
||||
auto typeLoc = getToken().getLoc();
|
||||
auto type = parseType();
|
||||
if (!type)
|
||||
return {};
|
||||
CustomAsmParser parser(*this);
|
||||
Attribute result;
|
||||
if (auto intType = type.dyn_cast<IntegerType>()) {
|
||||
switch (type.getIntOrFloatBitWidth()) {
|
||||
case 8:
|
||||
result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 16:
|
||||
result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 32:
|
||||
result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 64:
|
||||
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
default:
|
||||
emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
|
||||
return {};
|
||||
}
|
||||
} else if (auto floatType = type.dyn_cast<FloatType>()) {
|
||||
switch (type.getIntOrFloatBitWidth()) {
|
||||
case 32:
|
||||
result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 64:
|
||||
result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
default:
|
||||
emitError(typeLoc, "expected f32 or f64 but got: ") << type;
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
emitError(typeLoc, "expected integer or float type, got: ") << type;
|
||||
return {};
|
||||
}
|
||||
if (!consumeIf(Token::r_square)) {
|
||||
emitError("expected ']' to close an array attribute");
|
||||
return {};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Parse a dense elements attribute.
|
||||
Attribute Parser::parseDenseElementsAttr(Type attrType) {
|
||||
auto attribLoc = getToken().getLoc();
|
||||
|
|
|
@ -264,6 +264,9 @@ public:
|
|||
Attribute parseDenseElementsAttr(Type attrType);
|
||||
ShapedType parseElementsLiteralType(Type type);
|
||||
|
||||
/// Parse a DenseArrayAttr.
|
||||
Attribute parseDenseArrayAttr();
|
||||
|
||||
/// Parse a sparse elements attribute.
|
||||
Attribute parseSparseElementsAttr(Type attrType);
|
||||
|
||||
|
|
|
@ -513,6 +513,45 @@ func.func @simple_scalar_example() {
|
|||
return
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test DenseArrayAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @dense_array_attr
|
||||
func.func @dense_array_attr() attributes{
|
||||
// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
|
||||
f32attr = [:f32 1024., 453., -6435.],
|
||||
// CHECK-SAME: f64attr = [:f64 -1.420000e+02],
|
||||
f64attr = [:f64 -142.],
|
||||
// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
|
||||
i16attr = [:i16 3, 5, -4, 10],
|
||||
// CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
|
||||
i32attr = [:i32 1024, 453, -6435],
|
||||
// CHECK-SAME: i64attr = [:i64 -142],
|
||||
i64attr = [:i64 -142],
|
||||
// CHECK-SAME: i8attr = [:i8 1, -2, 3]
|
||||
i8attr = [:i8 1, -2, 3]
|
||||
} {
|
||||
// CHECK: test.dense_array_attr
|
||||
test.dense_array_attr
|
||||
// CHECK-SAME: i8attr = [1, -2, 3]
|
||||
i8attr = [1, -2, 3]
|
||||
// CHECK-SAME: i16attr = [3, 5, -4, 10]
|
||||
i16attr = [3, 5, -4, 10]
|
||||
// CHECK-SAME: i32attr = [1024, 453, -6435]
|
||||
i32attr = [1024, 453, -6435]
|
||||
// CHECK-SAME: i64attr = [-142]
|
||||
i64attr = [-142]
|
||||
// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03]
|
||||
f32attr = [1024., 453., -6435.]
|
||||
// CHECK-SAME: f64attr = [-1.420000e+02]
|
||||
f64attr = [-142.]
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -5,23 +5,40 @@
|
|||
// This tests that the abstract iteration of ElementsAttr works properly, and
|
||||
// is properly failable when necessary.
|
||||
|
||||
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
|
||||
// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
|
||||
// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
|
||||
// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
|
||||
arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64>
|
||||
|
||||
// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}}
|
||||
// expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
|
||||
// expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
|
||||
// expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
|
||||
arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
|
||||
|
||||
// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
|
||||
// expected-error@below {{Test iterating `uint64_t`: unable to iterate type}}
|
||||
// expected-error@below {{Test iterating `APInt`: unable to iterate type}}
|
||||
// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}}
|
||||
arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64>
|
||||
|
||||
// Check that we don't crash on empty element attributes.
|
||||
// expected-error@below {{Test iterating `int64_t`: }}
|
||||
// expected-error@below {{Test iterating `uint64_t`: }}
|
||||
// expected-error@below {{Test iterating `APInt`: }}
|
||||
// expected-error@below {{Test iterating `IntegerAttr`: }}
|
||||
arith.constant dense<> : tensor<0xi64>
|
||||
|
||||
// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
|
||||
arith.constant [:i8 10, 11, -12, 13, 14]
|
||||
// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
|
||||
arith.constant [:i16 10, 11, -12, 13, 14]
|
||||
// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
|
||||
arith.constant [:i32 10, 11, -12, 13, 14]
|
||||
// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
|
||||
arith.constant [:i64 10, 11, -12, 13, 14]
|
||||
// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
|
||||
arith.constant [:f32 10., 11., -12., 13., 14.]
|
||||
// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
|
||||
arith.constant [:f64 10., 11., -12., 13., 14.]
|
||||
|
|
|
@ -1654,7 +1654,7 @@ func.func @foo() {} // expected-error {{expected non-empty function body}}
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{expected ']'}}
|
||||
// expected-error@+1 {{expected ',' or ']'}}
|
||||
"f"() { b = [@m:
|
||||
|
||||
// -----
|
||||
|
|
|
@ -270,6 +270,22 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
|
|||
);
|
||||
}
|
||||
|
||||
def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
|
||||
let arguments = (ins
|
||||
DenseI8ArrayAttr:$i8attr,
|
||||
DenseI16ArrayAttr:$i16attr,
|
||||
DenseI32ArrayAttr:$i32attr,
|
||||
DenseI64ArrayAttr:$i64attr,
|
||||
DenseF32ArrayAttr:$f32attr,
|
||||
DenseF64ArrayAttr:$f64attr
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
`i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
|
||||
`i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
|
||||
attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Enum Attributes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -14,6 +14,17 @@
|
|||
using namespace mlir;
|
||||
using namespace test;
|
||||
|
||||
// Helper to print one scalar value, force int8_t to print as integer instead of
|
||||
// char.
|
||||
template <typename T>
|
||||
static void printOneElement(InFlightDiagnostic &os, T value) {
|
||||
os << llvm::formatv("{0}", value).str();
|
||||
}
|
||||
template <>
|
||||
void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
|
||||
os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct TestElementsAttrInterface
|
||||
: public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
|
||||
|
@ -29,6 +40,31 @@ struct TestElementsAttrInterface
|
|||
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
|
||||
if (!elementsAttr)
|
||||
continue;
|
||||
if (auto concreteAttr =
|
||||
attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
|
||||
switch (concreteAttr.getElementType()) {
|
||||
case DenseArrayBaseAttr::EltType::I8:
|
||||
testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I16:
|
||||
testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I32:
|
||||
testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I64:
|
||||
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::F32:
|
||||
testElementsAttrIteration<float>(op, elementsAttr, "float");
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::F64:
|
||||
testElementsAttrIteration<double>(op, elementsAttr, "double");
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
|
||||
testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
|
||||
testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
|
||||
testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
|
||||
|
@ -48,9 +84,8 @@ struct TestElementsAttrInterface
|
|||
return;
|
||||
}
|
||||
|
||||
llvm::interleaveComma(*values, diag, [&](T value) {
|
||||
diag << llvm::formatv("{0}", value).str();
|
||||
});
|
||||
llvm::interleaveComma(*values, diag,
|
||||
[&](T value) { printOneElement(diag, value); });
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in New Issue