diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index 99bb2b68cc62..ac98bfef1b58 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -16,160 +16,14 @@ namespace mlir { class AffineMap; +class BoolAttr; +class DenseIntElementsAttr; class FlatSymbolRefAttr; class FunctionType; class IntegerSet; +class IntegerType; class Location; class ShapedType; -} // namespace mlir - -//===----------------------------------------------------------------------===// -// Tablegen Attribute Declarations -//===----------------------------------------------------------------------===// - -#define GET_ATTRDEF_CLASSES -#include "mlir/IR/BuiltinAttributes.h.inc" - -//===----------------------------------------------------------------------===// -// C++ Attribute Declarations -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace detail { - -struct IntegerAttributeStorage; -struct FloatAttributeStorage; -struct SymbolRefAttributeStorage; -struct TypeAttributeStorage; - -/// Elements Attributes. -struct DenseIntOrFPElementsAttributeStorage; -struct DenseStringElementsAttributeStorage; -struct OpaqueElementsAttributeStorage; -struct SparseElementsAttributeStorage; -} // namespace detail - -//===----------------------------------------------------------------------===// -// FloatAttr -//===----------------------------------------------------------------------===// - -class FloatAttr : public Attribute::AttrBase { -public: - using Base::Base; - using Base::getChecked; - using ValueType = APFloat; - - /// Return a float attribute for the specified value in the specified type. - /// These methods should only be used for simple constant values, e.g 1.0/2.0, - /// that are known-valid both as host double and the 'type' format. - static FloatAttr get(Type type, double value); - static FloatAttr getChecked(function_ref emitError, - Type type, double value); - - /// Return a float attribute for the specified value in the specified type. - static FloatAttr get(Type type, const APFloat &value); - static FloatAttr getChecked(function_ref emitError, - Type type, const APFloat &value); - - APFloat getValue() const; - - /// This function is used to convert the value to a double, even if it loses - /// precision. - double getValueAsDouble() const; - static double getValueAsDouble(APFloat val); - - /// Verify the construction invariants for a double value. - static LogicalResult verify(function_ref emitError, - Type type, double value); - static LogicalResult verify(function_ref emitError, - Type type, const APFloat &value); -}; - -//===----------------------------------------------------------------------===// -// IntegerAttr -//===----------------------------------------------------------------------===// - -class IntegerAttr - : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = APInt; - - static IntegerAttr get(Type type, int64_t value); - static IntegerAttr get(Type type, const APInt &value); - - APInt getValue() const; - /// Return the integer value as a 64-bit int. The attribute must be a signless - /// integer. - // TODO: Change callers to use getValue instead. - int64_t getInt() const; - /// Return the integer value as a signed 64-bit int. The attribute must be - /// a signed integer. - int64_t getSInt() const; - /// Return the integer value as a unsigned 64-bit int. The attribute must be - /// an unsigned integer. - uint64_t getUInt() const; - - static LogicalResult verify(function_ref emitError, - Type type, int64_t value); - static LogicalResult verify(function_ref emitError, - Type type, const APInt &value); -}; - -//===----------------------------------------------------------------------===// -// BoolAttr - -/// Special case of IntegerAttr to represent boolean integers, i.e., signless i1 -/// integers. -class BoolAttr : public Attribute { -public: - using Attribute::Attribute; - using ValueType = bool; - - static BoolAttr get(MLIRContext *context, bool value); - - /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to - /// avoid bringing in all of IntegerAttrs methods. - operator IntegerAttr() const { return IntegerAttr(impl); } - - /// Return the boolean value of this attribute. - bool getValue() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Attribute attr); -}; - -//===----------------------------------------------------------------------===// -// FlatSymbolRefAttr -//===----------------------------------------------------------------------===// - -/// A symbol reference with a reference path containing a single element. This -/// is used to refer to an operation within the current symbol table. -class FlatSymbolRefAttr : public SymbolRefAttr { -public: - using SymbolRefAttr::SymbolRefAttr; - using ValueType = StringRef; - - /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) { - return SymbolRefAttr::get(ctx, value); - } - - /// Returns the name of the held symbol reference. - StringRef getValue() const { return getRootReference(); } - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Attribute attr) { - SymbolRefAttr refAttr = attr.dyn_cast(); - return refAttr && refAttr.getNestedReferences().empty(); - } - -private: - using SymbolRefAttr::get; - using SymbolRefAttr::getNestedReferences; -}; //===----------------------------------------------------------------------===// // Elements Attributes @@ -751,88 +605,91 @@ protected: bool isValidComplex(int64_t dataEltSize, bool isInt, bool isSigned) const; }; -/// An attribute class for representing dense arrays of strings. The structure -/// storing and querying a list of densely packed strings. -class DenseStringElementsAttr - : public Attribute::AttrBase { - +/// An attribute that represents a reference to a splat vector or tensor +/// constant, meaning all of the elements have the same value. +class SplatElementsAttr : public DenseElementsAttr { public: - using Base::Base; + using DenseElementsAttr::DenseElementsAttr; - /// Overload of the raw 'get' method that asserts that the given type is of - /// integer or floating-point type. This method is used to verify type - /// invariants that the templatized 'get' method cannot. - static DenseStringElementsAttr get(ShapedType type, ArrayRef data); - -protected: - friend DenseElementsAttr; + /// Method for support type inquiry through isa, cast and dyn_cast. + static bool classof(Attribute attr) { + auto denseAttr = attr.dyn_cast(); + return denseAttr && denseAttr.isSplat(); + } }; -/// An attribute class for specializing behavior of Int and Floating-point -/// densely packed string arrays. -class DenseIntOrFPElementsAttr - : public Attribute::AttrBase { +} // namespace mlir +//===----------------------------------------------------------------------===// +// Tablegen Attribute Declarations +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/IR/BuiltinAttributes.h.inc" + +//===----------------------------------------------------------------------===// +// C++ Attribute Declarations +//===----------------------------------------------------------------------===// + +namespace mlir { +//===----------------------------------------------------------------------===// +// BoolAttr +//===----------------------------------------------------------------------===// + +/// Special case of IntegerAttr to represent boolean integers, i.e., signless i1 +/// integers. +class BoolAttr : public Attribute { public: - using Base::Base; + using Attribute::Attribute; + using ValueType = bool; - /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of - /// the elements of `inRawData` has `type`. If `inRawData` is little endian - /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is - /// BE, converted to LE. - static void - convertEndianOfArrayRefForBEmachine(ArrayRef inRawData, - MutableArrayRef outRawData, - ShapedType type); + static BoolAttr get(MLIRContext *context, bool value); - /// Convert endianess of input for big-endian(BE) machines. The number of - /// elements of `inRawData` is `numElements`, and each element has - /// `elementBitWidth` bits. If `inRawData` is little endian (LE), it is - /// converted to big endian (BE) and saved in `outRawData`. Conversely, if - /// `inRawData` is BE, converted to LE. - static void convertEndianOfCharForBEmachine(const char *inRawData, - char *outRawData, - size_t elementBitWidth, - size_t numElements); + /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to + /// avoid bringing in all of IntegerAttrs methods. + operator IntegerAttr() const { return IntegerAttr(impl); } -protected: - friend DenseElementsAttr; + /// Return the boolean value of this attribute. + bool getValue() const; - /// Constructs a dense elements attribute from an array of raw APFloat values. - /// Each APFloat value is expected to have the same bitwidth as the element - /// type of 'type'. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, bool isSplat); - - /// Constructs a dense elements attribute from an array of raw APInt values. - /// Each APInt value is expected to have the same bitwidth as the element type - /// of 'type'. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, - ArrayRef values, bool isSplat); - - /// Get or create a new dense elements attribute instance with the given raw - /// data buffer. 'type' must be a vector or tensor with static shape. - static DenseElementsAttr getRaw(ShapedType type, ArrayRef data, - bool isSplat); - - /// Overload of the raw 'get' method that asserts that the given type is of - /// complex type. This method is used to verify type invariants that the - /// templatized 'get' method cannot. - static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned); - - /// Overload of the raw 'get' method that asserts that the given type is of - /// integer or floating-point type. This method is used to verify type - /// invariants that the templatized 'get' method cannot. - static DenseElementsAttr getRawIntOrFloat(ShapedType type, - ArrayRef data, - int64_t dataEltSize, bool isInt, - bool isSigned); + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(Attribute attr); }; +//===----------------------------------------------------------------------===// +// FlatSymbolRefAttr +//===----------------------------------------------------------------------===// + +/// A symbol reference with a reference path containing a single element. This +/// is used to refer to an operation within the current symbol table. +class FlatSymbolRefAttr : public SymbolRefAttr { +public: + using SymbolRefAttr::SymbolRefAttr; + using ValueType = StringRef; + + /// Construct a symbol reference for the given value name. + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) { + return SymbolRefAttr::get(ctx, value); + } + + /// Returns the name of the held symbol reference. + StringRef getValue() const { return getRootReference(); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(Attribute attr) { + SymbolRefAttr refAttr = attr.dyn_cast(); + return refAttr && refAttr.getNestedReferences().empty(); + } + +private: + using SymbolRefAttr::get; + using SymbolRefAttr::getNestedReferences; +}; + +//===----------------------------------------------------------------------===// +// DenseFPElementsAttr +//===----------------------------------------------------------------------===// + /// An attribute that represents a reference to a dense float vector or tensor /// object. Each element is stored as a double. class DenseFPElementsAttr : public DenseIntOrFPElementsAttr { @@ -869,6 +726,10 @@ public: static bool classof(Attribute attr); }; +//===----------------------------------------------------------------------===// +// DenseIntElementsAttr +//===----------------------------------------------------------------------===// + /// An attribute that represents a reference to a dense integer vector or tensor /// object. class DenseIntElementsAttr : public DenseIntOrFPElementsAttr { @@ -906,170 +767,27 @@ public: static bool classof(Attribute attr); }; -/// An opaque attribute that represents a reference to a vector or tensor -/// constant with opaque content. This representation is for tensor constants -/// which the compiler may not need to interpret. This attribute is always -/// associated with a particular dialect, which provides a method to convert -/// tensor representation to a non-opaque format. -class OpaqueElementsAttr - : public Attribute::AttrBase { -public: - using Base::Base; - using ValueType = StringRef; +//===----------------------------------------------------------------------===// +// SparseElementsAttr +//===----------------------------------------------------------------------===// - static OpaqueElementsAttr get(Dialect *dialect, ShapedType type, - StringRef bytes); - - StringRef getValue() const; - - /// Return the value at the given index. The 'index' is expected to refer to a - /// valid element. - Attribute getValue(ArrayRef index) const; - - /// Decodes the attribute value using dialect-specific decoding hook. - /// Returns false if decoding is successful. If not, returns true and leaves - /// 'result' argument unspecified. - bool decode(ElementsAttr &result); - - /// Returns dialect associated with this opaque constant. - Dialect *getDialect() const; -}; - -/// An attribute that represents a reference to a sparse vector or tensor -/// object. -/// -/// This class uses COO (coordinate list) encoding to represent the sparse -/// elements in an element attribute. Specifically, the sparse vector/tensor -/// stores the indices and values as two separate dense elements attributes of -/// tensor type (even if the sparse attribute is of vector type, in order to -/// support empty lists). The dense elements attribute indices is a 2-D tensor -/// of 64-bit integer elements with shape [N, ndims], which specifies the -/// indices of the elements in the sparse tensor that contains nonzero values. -/// The dense elements attribute values is a 1-D tensor with shape [N], and it -/// supplies the corresponding values for the indices. -/// -/// For example, -/// `sparse, [[0, 0], [1, 2]], [1, 5]>` represents tensor -/// [[1, 0, 0, 0], -/// [0, 0, 5, 0], -/// [0, 0, 0, 0]]. -class SparseElementsAttr - : public Attribute::AttrBase { -public: - using Base::Base; - - template - using iterator = - llvm::mapped_iterator, - std::function>; - - /// 'type' must be a vector or tensor with static shape. - static SparseElementsAttr get(ShapedType type, DenseElementsAttr indices, - DenseElementsAttr values); - - DenseIntElementsAttr getIndices() const; - - DenseElementsAttr getValues() const; - - /// Return the values of this attribute in the form of the given type 'T'. 'T' - /// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc. - template llvm::iterator_range> getValues() const { - auto zeroValue = getZeroValue(); - auto valueIt = getValues().getValues().begin(); - const std::vector flatSparseIndices(getFlattenedSparseIndices()); - // TODO: Move-capture flatSparseIndices when c++14 is available. - std::function mapFn = [=](ptrdiff_t index) { - // Try to map the current index to one of the sparse indices. - for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i) - if (flatSparseIndices[i] == index) - return *std::next(valueIt, i); - // Otherwise, return the zero value. - return zeroValue; - }; - return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); - } - - /// Return the value of the element at the given index. The 'index' is - /// expected to refer to a valid element. - Attribute getValue(ArrayRef index) const; - -private: - /// Get a zero APFloat for the given sparse attribute. - APFloat getZeroAPFloat() const; - - /// Get a zero APInt for the given sparse attribute. - APInt getZeroAPInt() const; - - /// Get a zero attribute for the given sparse attribute. - Attribute getZeroAttr() const; - - /// Utility methods to generate a zero value of some type 'T'. This is used by - /// the 'iterator' class. - /// Get a zero for a given attribute type. - template - typename std::enable_if::value, T>::type - getZeroValue() const { - return getZeroAttr().template cast(); - } - /// Get a zero for an APInt. - template - typename std::enable_if::value, T>::type - getZeroValue() const { - return getZeroAPInt(); - } - template - typename std::enable_if, T>::value, T>::type - getZeroValue() const { - APInt intZero = getZeroAPInt(); - return {intZero, intZero}; - } - /// Get a zero for an APFloat. - template - typename std::enable_if::value, T>::type - getZeroValue() const { - return getZeroAPFloat(); - } - template - typename std::enable_if, T>::value, - T>::type - getZeroValue() const { - APFloat floatZero = getZeroAPFloat(); - return {floatZero, floatZero}; - } - - /// Get a zero for an C++ integer, float, StringRef, or complex type. - template - typename std::enable_if< - std::numeric_limits::is_integer || - DenseElementsAttr::is_valid_cpp_fp_type::value || - std::is_same::value || - (detail::is_complex_t::value && - !llvm::is_one_of, - std::complex>::value), - T>::type - getZeroValue() const { - return T(); - } - - /// Flatten, and return, all of the sparse indices in this attribute in - /// row-major order. - std::vector getFlattenedSparseIndices() const; -}; - -/// An attribute that represents a reference to a splat vector or tensor -/// constant, meaning all of the elements have the same value. -class SplatElementsAttr : public DenseElementsAttr { -public: - using DenseElementsAttr::DenseElementsAttr; - - /// Method for support type inquiry through isa, cast and dyn_cast. - static bool classof(Attribute attr) { - auto denseAttr = attr.dyn_cast(); - return denseAttr && denseAttr.isSplat(); - } -}; +template +auto SparseElementsAttr::getValues() const + -> llvm::iterator_range> { + auto zeroValue = getZeroValue(); + auto valueIt = getValues().getValues().begin(); + const std::vector flatSparseIndices(getFlattenedSparseIndices()); + // TODO: Move-capture flatSparseIndices when c++14 is available. + std::function mapFn = [=](ptrdiff_t index) { + // Try to map the current index to one of the sparse indices. + for (unsigned i = 0, e = flatSparseIndices.size(); i != e; ++i) + if (flatSparseIndices[i] == index) + return *std::next(valueIt, i); + // Otherwise, return the zero value. + return zeroValue; + }; + return llvm::map_range(llvm::seq(0, getNumElements()), mapFn); +} namespace detail { /// This class represents a general iterator over the values of an ElementsAttr. diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 530ab0b79e3f..433c33521a7a 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -22,7 +22,8 @@ include "mlir/IR/BuiltinDialect.td" // to this file instead. // Base class for Builtin dialect attributes. -class Builtin_Attr : AttrDef { +class Builtin_Attr + : AttrDef { let mnemonic = ?; } @@ -127,6 +128,151 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array"> { }]; } +//===----------------------------------------------------------------------===// +// DenseIntOrFPElementsAttr +//===----------------------------------------------------------------------===// + +def Builtin_DenseIntOrFPElementsAttr + : Builtin_Attr<"DenseIntOrFPElements", "DenseElementsAttr"> { + let summary = "An Attribute containing a dense multi-dimensional array of " + "integer or floating-point values"; + let description = [{ + Syntax: + + ``` + dense-intorfloat-elements-attribute ::= `dense` `<` attribute-value `>` `:` + ( tensor-type | vector-type ) + ``` + + A dense int-or-float elements attribute is an elements attribute containing + a densely packed vector or tensor of integer or floating-point values. The + element type of this attribute is required to be either an `IntegerType` or + a `FloatType`. + + Examples: + + ``` + // A splat tensor of integer values. + dense<10> : tensor<2xi32> + // A tensor of 2 float32 elements. + dense<[10.0, 11.0]> : tensor<2xf32> + ``` + }]; + let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, + "ArrayRef":$rawData); + let extraClassDeclaration = [{ + /// Convert endianess of input ArrayRef for big-endian(BE) machines. All of + /// the elements of `inRawData` has `type`. If `inRawData` is little endian + /// (LE), it is converted to big endian (BE). Conversely, if `inRawData` is + /// BE, converted to LE. + static void + convertEndianOfArrayRefForBEmachine(ArrayRef inRawData, + MutableArrayRef outRawData, + ShapedType type); + + /// Convert endianess of input for big-endian(BE) machines. The number of + /// elements of `inRawData` is `numElements`, and each element has + /// `elementBitWidth` bits. If `inRawData` is little endian (LE), it is + /// converted to big endian (BE) and saved in `outRawData`. Conversely, if + /// `inRawData` is BE, converted to LE. + static void convertEndianOfCharForBEmachine(const char *inRawData, + char *outRawData, + size_t elementBitWidth, + size_t numElements); + + protected: + friend DenseElementsAttr; + + /// Constructs a dense elements attribute from an array of raw APFloat + /// values. Each APFloat value is expected to have the same bitwidth as the + /// element type of 'type'. 'type' must be a vector or tensor with static + /// shape. + static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + ArrayRef values, bool isSplat); + + /// Constructs a dense elements attribute from an array of raw APInt values. + /// Each APInt value is expected to have the same bitwidth as the element + /// type of 'type'. 'type' must be a vector or tensor with static shape. + static DenseElementsAttr getRaw(ShapedType type, size_t storageWidth, + ArrayRef values, bool isSplat); + + /// Get or create a new dense elements attribute instance with the given raw + /// data buffer. 'type' must be a vector or tensor with static shape. + static DenseElementsAttr getRaw(ShapedType type, ArrayRef data, + bool isSplat); + + /// Overload of the raw 'get' method that asserts that the given type is of + /// complex type. This method is used to verify type invariants that the + /// templatized 'get' method cannot. + static DenseElementsAttr getRawComplex(ShapedType type, ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned); + + /// Overload of the raw 'get' method that asserts that the given type is of + /// integer or floating-point type. This method is used to verify type + /// invariants that the templatized 'get' method cannot. + static DenseElementsAttr getRawIntOrFloat(ShapedType type, + ArrayRef data, + int64_t dataEltSize, bool isInt, + bool isSigned); + + public: + }]; + let genAccessors = 0; + let genStorageClass = 0; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// DenseStringElementsAttr +//===----------------------------------------------------------------------===// + +def Builtin_DenseStringElementsAttr + : Builtin_Attr<"DenseStringElements", "DenseElementsAttr"> { + let summary = "An Attribute containing a dense multi-dimensional array of " + "strings"; + let description = [{ + Syntax: + + ``` + dense-string-elements-attribute ::= `dense` `<` attribute-value `>` `:` + ( tensor-type | vector-type ) + ``` + + A dense string elements attribute is an elements attribute containing a + densely packed vector or tensor of string values. There are no restrictions + placed on the element type of this attribute, enabling the use of dialect + specific string types. + + Examples: + + ``` + // A splat tensor of strings. + dense<"example"> : tensor<2x!foo.string> + // A tensor of 2 string elements. + dense<["example1", "example2"]> : tensor<2x!foo.string> + ``` + }]; + let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, + "ArrayRef":$value); + let builders = [ + AttrBuilderWithInferredContext<(ins "ShapedType":$type, + "ArrayRef":$values), [{ + return $_get(type.getContext(), type, values, + /* isSplat */(values.size() == 1)); + }]>, + ]; + let extraClassDeclaration = [{ + protected: + friend DenseElementsAttr; + + public: + }]; + let genAccessors = 0; + let genStorageClass = 0; + let skipDefaultBuilders = 1; +} + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// @@ -220,6 +366,147 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> { let skipDefaultBuilders = 1; } +//===----------------------------------------------------------------------===// +// FloatAttr +//===----------------------------------------------------------------------===// + +def Builtin_FloatAttr : Builtin_Attr<"Float"> { + let summary = "An Attribute containing a floating-point value"; + let description = [{ + Syntax: + + ``` + float-attribute ::= (float-literal (`:` float-type)?) + | (hexadecimal-literal `:` float-type) + ``` + + A float attribute is a literal attribute that represents a floating point + value of the specified [float type](#floating-point-types). It can be + represented in the hexadecimal form where the hexadecimal value is + interpreted as bits of the underlying binary representation. This form is + useful for representing infinity and NaN floating point values. To avoid + confusion with integer attributes, hexadecimal literals _must_ be followed + by a float type to define a float attribute. + + Examples: + + ``` + 42.0 // float attribute defaults to f64 type + 42.0 : f32 // float attribute of f32 type + 0x7C00 : f16 // positive infinity + 0x7CFF : f16 // NaN (one of possible values) + 42 : f32 // Error: expected integer type + ``` + }]; + let parameters = (ins AttributeSelfTypeParameter<"">:$type, + APFloatParameter<"">:$value); + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const APFloat &":$value), [{ + return $_get(type.getContext(), type, value); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, "double":$value), [{ + if (type.isF64()) + return $_get(type.getContext(), type, APFloat(value)); + + // This handles, e.g., F16 because there is no APFloat constructor for it. + bool unused; + APFloat val(value); + val.convert(type.cast().getFloatSemantics(), + APFloat::rmNearestTiesToEven, &unused); + return $_get(type.getContext(), type, val); + }]> + ]; + let extraClassDeclaration = [{ + using ValueType = APFloat; + + /// This function is used to convert the value to a double, even if it loses + /// precision. + double getValueAsDouble() const; + static double getValueAsDouble(APFloat val); + }]; + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// IntegerAttr +//===----------------------------------------------------------------------===// + +def Builtin_IntegerAttr : Builtin_Attr<"Integer"> { + let summary = "An Attribute containing a integer value"; + let description = [{ + Syntax: + + ``` + integer-attribute ::= (integer-literal ( `:` (index-type | integer-type) )?) + | `true` | `false` + ``` + + An integer attribute is a literal attribute that represents an integral + value of the specified integer or index type. `i1` integer attributes are + treated as `boolean` attributes, and use a unique assembly format of either + `true` or `false` depending on the value. The default type for non-boolean + integer attributes, if a type is not specified, is signless 64-bit integer. + + Examples: + + ```mlir + 10 : i32 + 10 // : i64 is implied here. + true // A bool, i.e. i1, value. + false // A bool, i.e. i1, value. + ``` + }]; + let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value); + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const APInt &":$value), [{ + if (type.isSignlessInteger(1)) + return BoolAttr::get(type.getContext(), value.getBoolValue()); + return $_get(type.getContext(), type, value); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, "int64_t":$value), [{ + // `index` has a defined internal storage width. + if (type.isIndex()) { + APInt apValue(IndexType::kInternalStorageBitWidth, value); + return $_get(type.getContext(), type, apValue); + } + + IntegerType intTy = type.cast(); + APInt apValue(intTy.getWidth(), value, intTy.isSignedInteger()); + return $_get(type.getContext(), type, apValue); + }]> + ]; + let extraClassDeclaration = [{ + using ValueType = APInt; + + /// Return the integer value as a 64-bit int. The attribute must be a + /// signless integer. + // TODO: Change callers to use getValue instead. + int64_t getInt() const; + /// Return the integer value as a signed 64-bit int. The attribute must be + /// a signed integer. + int64_t getSInt() const; + /// Return the integer value as a unsigned 64-bit int. The attribute must be + /// an unsigned integer. + uint64_t getUInt() const; + + private: + /// Return a boolean attribute. This is a special variant of the `get` + /// method that is used by the MLIRContext to cache the boolean IntegerAttr + /// instances. + static BoolAttr getBoolAttrUnchecked(IntegerType type, bool value); + + /// Allow access to `getBoolAttrUnchecked`. + friend MLIRContext; + + public: + }]; + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + //===----------------------------------------------------------------------===// // IntegerSetAttr //===----------------------------------------------------------------------===// @@ -282,8 +569,212 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> { return $_get(dialect.getContext(), dialect, attrData, type); }]> ]; - bit genVerifyDecl = 1; - // let skipDefaultBuilders = 1; + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// OpaqueElementsAttr +//===----------------------------------------------------------------------===// + +def Builtin_OpaqueElementsAttr + : Builtin_Attr<"OpaqueElements", "ElementsAttr"> { + let summary = "An opaque representation of a multi-dimensional array"; + let description = [{ + Syntax: + + ``` + opaque-elements-attribute ::= `opaque` `<` dialect-namespace `,` + hex-string-literal `>` `:` + ( tensor-type | vector-type ) + ``` + + An opaque elements attribute is an elements attribute where the content of + the value is opaque. The representation of the constant stored by this + elements attribute is only understood, and thus decodable, by the dialect + that created it. + + Note: The parsed string literal must be in hexadecimal form. + + Examples: + + ```mlir + opaque<"foo_dialect", "0xDEADBEEF"> : tensor<10xi32> + ``` + }]; + + // TODO: Provide a way to avoid copying content of large opaque + // tensors This will likely require a new reference attribute kind. + let parameters = (ins "Identifier":$dialect, + StringRefParameter<"">:$value, + AttributeSelfTypeParameter<"", "ShapedType">:$type); + let builders = [ + AttrBuilderWithInferredContext<(ins "Identifier":$dialect, + "ShapedType":$type, + "StringRef":$value), [{ + return $_get(dialect.getContext(), dialect, value, type); + }]>, + AttrBuilderWithInferredContext<(ins "Dialect *":$dialect, + "ShapedType":$type, + "StringRef":$value), [{ + MLIRContext *ctxt = dialect->getContext(); + Identifier dialectName = Identifier::get(dialect->getNamespace(), ctxt); + return $_get(ctxt, dialectName, value, type); + }]> + ]; + let extraClassDeclaration = [{ + using ValueType = StringRef; + + /// Return the value at the given index. The 'index' is expected to refer to + /// a valid element. + Attribute getValue(ArrayRef index) const; + + /// Decodes the attribute value using dialect-specific decoding hook. + /// Returns false if decoding is successful. If not, returns true and leaves + /// 'result' argument unspecified. + bool decode(ElementsAttr &result); + + }]; + let genVerifyDecl = 1; + let skipDefaultBuilders = 1; +} + +//===----------------------------------------------------------------------===// +// SparseElementsAttr +//===----------------------------------------------------------------------===// + +def Builtin_SparseElementsAttr + : Builtin_Attr<"SparseElements", "ElementsAttr"> { + let summary = "An opaque representation of a multi-dimensional array"; + let description = [{ + Syntax: + + ``` + sparse-elements-attribute ::= `sparse` `<` attribute-value `,` + attribute-value `>` `:` + ( tensor-type | vector-type ) + ``` + + A sparse elements attribute is an elements attribute that represents a + sparse vector or tensor object. This is where very few of the elements are + non-zero. + + The attribute uses COO (coordinate list) encoding to represent the sparse + elements of the elements attribute. The indices are stored via a 2-D tensor + of 64-bit integer elements with shape [N, ndims], which specifies the + indices of the elements in the sparse tensor that contains non-zero values. + The element values are stored via a 1-D tensor with shape [N], that supplies + the corresponding values for the indices. + + Example: + + ```mlir + sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32> + + // This represents the following tensor: + /// [[1, 0, 0, 0], + /// [0, 0, 5, 0], + /// [0, 0, 0, 0]] + ``` + }]; + + let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, + "DenseIntElementsAttr":$indices, + "DenseElementsAttr":$values); + let builders = [ + AttrBuilderWithInferredContext<(ins "ShapedType":$type, + "DenseElementsAttr":$indices, + "DenseElementsAttr":$values), [{ + assert(indices.getType().getElementType().isInteger(64) && + "expected sparse indices to be 64-bit integer values"); + assert((type.isa()) && + "type must be ranked tensor or vector"); + assert(type.hasStaticShape() && "type must have static shape"); + return $_get(type.getContext(), type, + indices.cast(), values); + }]>, + ]; + let extraClassDeclaration = [{ + template + using iterator = + llvm::mapped_iterator, + std::function>; + + /// Return the values of this attribute in the form of the given type 'T'. + /// 'T' may be any of Attribute, APInt, APFloat, c++ integer/float types, + /// etc. + template llvm::iterator_range> getValues() const; + + /// Return the value of the element at the given index. The 'index' is + /// expected to refer to a valid element. + Attribute getValue(ArrayRef index) const; + + private: + /// Get a zero APFloat for the given sparse attribute. + APFloat getZeroAPFloat() const; + + /// Get a zero APInt for the given sparse attribute. + APInt getZeroAPInt() const; + + /// Get a zero attribute for the given sparse attribute. + Attribute getZeroAttr() const; + + /// Utility methods to generate a zero value of some type 'T'. This is used + /// by the 'iterator' class. + /// Get a zero for a given attribute type. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAttr().template cast(); + } + /// Get a zero for an APInt. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAPInt(); + } + template + typename std::enable_if, T>::value, + T>::type + getZeroValue() const { + APInt intZero = getZeroAPInt(); + return {intZero, intZero}; + } + /// Get a zero for an APFloat. + template + typename std::enable_if::value, T>::type + getZeroValue() const { + return getZeroAPFloat(); + } + template + typename std::enable_if, T>::value, + T>::type + getZeroValue() const { + APFloat floatZero = getZeroAPFloat(); + return {floatZero, floatZero}; + } + + /// Get a zero for an C++ integer, float, StringRef, or complex type. + template + typename std::enable_if< + std::numeric_limits::is_integer || + DenseElementsAttr::is_valid_cpp_fp_type::value || + std::is_same::value || + (detail::is_complex_t::value && + !llvm::is_one_of, + std::complex>::value), + T>::type + getZeroValue() const { + return T(); + } + + /// Flatten, and return, all of the sparse indices in this attribute in + /// row-major order. + std::vector getFlattenedSparseIndices() const; + + public: + }]; + let skipDefaultBuilders = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 819badc8b0f4..844f7685df7f 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2723,7 +2723,8 @@ class ArrayRefOfSelfAllocationParameter : // This is a special parameter used for AttrDefs that represents a `mlir::Type` // that is also used as the value `Type` of the attribute. Only one parameter // of the attribute may be of this type. -class AttributeSelfTypeParameter : - AttrOrTypeParameter<"::mlir::Type", desc> {} +class AttributeSelfTypeParameter : + AttrOrTypeParameter {} #endif // OP_BASE diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index b6d327b1c78b..ad2c3c6a9075 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1515,7 +1515,7 @@ static void printSymbolReference(StringRef symbolRef, raw_ostream &os) { // accept the string "elided". The first string must be a registered dialect // name and the latter must be a hex constant. static void printElidedElementsAttr(raw_ostream &os) { - os << R"(opaque<"", "0xDEADBEEF">)"; + os << R"(opaque<"_", "0xDEADBEEF">)"; } void ModulePrinter::printAttribute(Attribute attr, @@ -1610,8 +1610,8 @@ void ModulePrinter::printAttribute(Attribute attr, if (printerFlags.shouldElideElementsAttr(opaqueAttr)) { printElidedElementsAttr(os); } else { - os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", "; - os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">"; + os << "opaque<\"" << opaqueAttr.getDialect() << "\", \"0x" + << llvm::toHex(opaqueAttr.getValue()) << "\">"; } } else if (auto intOrFpEltAttr = attr.dyn_cast()) { diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 9499a0f84c84..f62d886cacb6 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -27,113 +27,6 @@ namespace mlir { namespace detail { -/// An attribute representing a floating point value. -struct FloatAttributeStorage final - : public AttributeStorage, - public llvm::TrailingObjects { - using KeyTy = std::pair; - - FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type, - size_t numObjects) - : AttributeStorage(type), semantics(semantics), numObjects(numObjects) {} - - /// Key equality and hash functions. - bool operator==(const KeyTy &key) const { - return key.first == getType() && key.second.bitwiseIsEqual(getValue()); - } - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(key.first, llvm::hash_value(key.second)); - } - - /// Construct a key with a type and double. - static KeyTy getKey(Type type, double value) { - if (type.isF64()) - return KeyTy(type, APFloat(value)); - - // This handles, e.g., F16 because there is no APFloat constructor for it. - bool unused; - APFloat val(value); - val.convert(type.cast().getFloatSemantics(), - APFloat::rmNearestTiesToEven, &unused); - return KeyTy(type, val); - } - - /// Construct a new storage instance. - static FloatAttributeStorage *construct(AttributeStorageAllocator &allocator, - const KeyTy &key) { - const auto &apint = key.second.bitcastToAPInt(); - - // Here one word's bitwidth equals to that of uint64_t. - auto elements = ArrayRef(apint.getRawData(), apint.getNumWords()); - - auto byteSize = - FloatAttributeStorage::totalSizeToAlloc(elements.size()); - auto rawMem = allocator.allocate(byteSize, alignof(FloatAttributeStorage)); - auto result = ::new (rawMem) FloatAttributeStorage( - key.second.getSemantics(), key.first, elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), - result->getTrailingObjects()); - return result; - } - - /// Returns an APFloat representing the stored value. - APFloat getValue() const { - auto val = APInt(APFloat::getSizeInBits(semantics), - {getTrailingObjects(), numObjects}); - return APFloat(semantics, val); - } - - const llvm::fltSemantics &semantics; - size_t numObjects; -}; - -/// An attribute representing an integral value. -struct IntegerAttributeStorage final - : public AttributeStorage, - public llvm::TrailingObjects { - using KeyTy = std::pair; - - IntegerAttributeStorage(Type type, size_t numObjects) - : AttributeStorage(type), numObjects(numObjects) { - assert((type.isIndex() || type.isa()) && "invalid type"); - } - - /// Key equality and hash functions. - bool operator==(const KeyTy &key) const { - return key == KeyTy(getType(), getValue()); - } - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(key.first, llvm::hash_value(key.second)); - } - - /// Construct a new storage instance. - static IntegerAttributeStorage * - construct(AttributeStorageAllocator &allocator, const KeyTy &key) { - Type type; - APInt value; - std::tie(type, value) = key; - - auto elements = ArrayRef(value.getRawData(), value.getNumWords()); - auto size = - IntegerAttributeStorage::totalSizeToAlloc(elements.size()); - auto rawMem = allocator.allocate(size, alignof(IntegerAttributeStorage)); - auto result = ::new (rawMem) IntegerAttributeStorage(type, elements.size()); - std::uninitialized_copy(elements.begin(), elements.end(), - result->getTrailingObjects()); - return result; - } - - /// Returns an APInt representing the stored value. - APInt getValue() const { - if (getType().isIndex()) - return APInt(64, {getTrailingObjects(), numObjects}); - return APInt(getType().getIntOrFloatBitWidth(), - {getTrailingObjects(), numObjects}); - } - - size_t numObjects; -}; - //===----------------------------------------------------------------------===// // Elements Attributes //===----------------------------------------------------------------------===// @@ -158,10 +51,9 @@ public: }; /// An attribute representing a reference to a dense vector or tensor object. -struct DenseIntOrFPElementsAttributeStorage - : public DenseElementsAttributeStorage { - DenseIntOrFPElementsAttributeStorage(ShapedType ty, ArrayRef data, - bool isSplat = false) +struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage { + DenseIntOrFPElementsAttrStorage(ShapedType ty, ArrayRef data, + bool isSplat = false) : DenseElementsAttributeStorage(ty, isSplat), data(data) {} struct KeyTy { @@ -287,7 +179,7 @@ struct DenseIntOrFPElementsAttributeStorage } /// Construct a new storage instance. - static DenseIntOrFPElementsAttributeStorage * + static DenseIntOrFPElementsAttrStorage * construct(AttributeStorageAllocator &allocator, KeyTy key) { // If the data buffer is non-empty, we copy it into the allocator with a // 64-bit alignment. @@ -303,8 +195,8 @@ struct DenseIntOrFPElementsAttributeStorage copy = ArrayRef(rawData, data.size()); } - return new (allocator.allocate()) - DenseIntOrFPElementsAttributeStorage(key.type, copy, key.isSplat); + return new (allocator.allocate()) + DenseIntOrFPElementsAttrStorage(key.type, copy, key.isSplat); } ArrayRef data; @@ -312,10 +204,9 @@ struct DenseIntOrFPElementsAttributeStorage /// An attribute representing a reference to a dense vector or tensor object /// containing strings. -struct DenseStringElementsAttributeStorage - : public DenseElementsAttributeStorage { - DenseStringElementsAttributeStorage(ShapedType ty, ArrayRef data, - bool isSplat = false) +struct DenseStringElementsAttrStorage : public DenseElementsAttributeStorage { + DenseStringElementsAttrStorage(ShapedType ty, ArrayRef data, + bool isSplat = false) : DenseElementsAttributeStorage(ty, isSplat), data(data) {} struct KeyTy { @@ -385,14 +276,14 @@ struct DenseStringElementsAttributeStorage } /// Construct a new storage instance. - static DenseStringElementsAttributeStorage * + static DenseStringElementsAttrStorage * construct(AttributeStorageAllocator &allocator, KeyTy key) { // If the data buffer is non-empty, we copy it into the allocator with a // 64-bit alignment. ArrayRef copy, data = key.data; if (data.empty()) { - return new (allocator.allocate()) - DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); + return new (allocator.allocate()) + DenseStringElementsAttrStorage(key.type, copy, key.isSplat); } int numEntries = key.isSplat ? 1 : data.size(); @@ -421,72 +312,13 @@ struct DenseStringElementsAttributeStorage copy = ArrayRef(reinterpret_cast(rawData), numEntries); - return new (allocator.allocate()) - DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); + return new (allocator.allocate()) + DenseStringElementsAttrStorage(key.type, copy, key.isSplat); } ArrayRef data; }; -/// An attribute representing a reference to a tensor constant with opaque -/// content. -struct OpaqueElementsAttributeStorage : public AttributeStorage { - using KeyTy = std::tuple; - - OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes) - : AttributeStorage(type), dialect(dialect), bytes(bytes) {} - - /// Key equality and hash functions. - bool operator==(const KeyTy &key) const { - return key == std::make_tuple(getType(), dialect, bytes); - } - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(std::get<0>(key), std::get<1>(key), - std::get<2>(key)); - } - - /// Construct a new storage instance. - static OpaqueElementsAttributeStorage * - construct(AttributeStorageAllocator &allocator, KeyTy key) { - // TODO: Provide a way to avoid copying content of large opaque - // tensors This will likely require a new reference attribute kind. - return new (allocator.allocate()) - OpaqueElementsAttributeStorage(std::get<0>(key), std::get<1>(key), - allocator.copyInto(std::get<2>(key))); - } - - Dialect *dialect; - StringRef bytes; -}; - -/// An attribute representing a reference to a sparse vector or tensor object. -struct SparseElementsAttributeStorage : public AttributeStorage { - using KeyTy = std::tuple; - - SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices, - DenseElementsAttr values) - : AttributeStorage(type), indices(indices), values(values) {} - - /// Key equality and hash functions. - bool operator==(const KeyTy &key) const { - return key == std::make_tuple(getType(), indices, values); - } - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(std::get<0>(key), std::get<1>(key), - std::get<2>(key)); - } - - /// Construct a new storage instance. - static SparseElementsAttributeStorage * - construct(AttributeStorageAllocator &allocator, KeyTy key) { - return new (allocator.allocate()) - SparseElementsAttributeStorage(std::get<0>(key), std::get<1>(key), - std::get<2>(key)); - } - - DenseIntElementsAttr indices; - DenseElementsAttr values; -}; } // namespace detail } // namespace mlir diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 5efb8f7c70ff..947ee143c963 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -202,26 +202,6 @@ DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { // FloatAttr //===----------------------------------------------------------------------===// -FloatAttr FloatAttr::get(Type type, double value) { - return Base::get(type.getContext(), type, value); -} - -FloatAttr FloatAttr::getChecked(function_ref emitError, - Type type, double value) { - return Base::getChecked(emitError, type.getContext(), type, value); -} - -FloatAttr FloatAttr::get(Type type, const APFloat &value) { - return Base::get(type.getContext(), type, value); -} - -FloatAttr FloatAttr::getChecked(function_ref emitError, - Type type, const APFloat &value) { - return Base::getChecked(emitError, type.getContext(), type, value); -} - -APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } - double FloatAttr::getValueAsDouble() const { return getValueAsDouble(getValue()); } @@ -234,25 +214,11 @@ double FloatAttr::getValueAsDouble(APFloat value) { return value.convertToDouble(); } -/// Verify construction invariants. -static LogicalResult -verifyFloatTypeInvariants(function_ref emitError, - Type type) { +LogicalResult FloatAttr::verify(function_ref emitError, + Type type, APFloat value) { + // Verify that the type is correct. if (!type.isa()) return emitError() << "expected floating point type"; - return success(); -} - -LogicalResult FloatAttr::verify(function_ref emitError, - Type type, double value) { - return verifyFloatTypeInvariants(emitError, type); -} - -LogicalResult FloatAttr::verify(function_ref emitError, - Type type, const APFloat &value) { - // Verify that the type is correct. - if (failed(verifyFloatTypeInvariants(emitError, type))) - return failure(); // Verify that the type semantics match that of the value. if (&type.cast().getFloatSemantics() != &value.getSemantics()) { @@ -279,72 +245,47 @@ StringRef SymbolRefAttr::getLeafReference() const { // IntegerAttr //===----------------------------------------------------------------------===// -IntegerAttr IntegerAttr::get(Type type, const APInt &value) { - if (type.isSignlessInteger(1)) - return BoolAttr::get(type.getContext(), value.getBoolValue()); - return Base::get(type.getContext(), type, value); -} - -IntegerAttr IntegerAttr::get(Type type, int64_t value) { - // This uses 64 bit APInts by default for index type. - if (type.isIndex()) - return get(type, APInt(IndexType::kInternalStorageBitWidth, value)); - - auto intType = type.cast(); - return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); -} - -APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } - int64_t IntegerAttr::getInt() const { - assert((getImpl()->getType().isIndex() || - getImpl()->getType().isSignlessInteger()) && + assert((getType().isIndex() || getType().isSignlessInteger()) && "must be signless integer"); return getValue().getSExtValue(); } int64_t IntegerAttr::getSInt() const { - assert(getImpl()->getType().isSignedInteger() && "must be signed integer"); + assert(getType().isSignedInteger() && "must be signed integer"); return getValue().getSExtValue(); } uint64_t IntegerAttr::getUInt() const { - assert(getImpl()->getType().isUnsignedInteger() && - "must be unsigned integer"); + assert(getType().isUnsignedInteger() && "must be unsigned integer"); return getValue().getZExtValue(); } -static LogicalResult -verifyIntegerTypeInvariants(function_ref emitError, - Type type) { - if (type.isa()) - return success(); - return emitError() << "expected integer or index type"; -} - LogicalResult IntegerAttr::verify(function_ref emitError, - Type type, int64_t value) { - return verifyIntegerTypeInvariants(emitError, type); -} - -LogicalResult IntegerAttr::verify(function_ref emitError, - Type type, const APInt &value) { - if (failed(verifyIntegerTypeInvariants(emitError, type))) - return failure(); - if (auto integerType = type.dyn_cast()) + Type type, APInt value) { + if (IntegerType integerType = type.dyn_cast()) { if (integerType.getWidth() != value.getBitWidth()) return emitError() << "integer type bit width (" << integerType.getWidth() << ") doesn't match value bit width (" << value.getBitWidth() << ")"; - return success(); + return success(); + } + if (type.isa()) + return success(); + return emitError() << "expected integer or index type"; +} + +BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { + auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); + return attr.cast(); } //===----------------------------------------------------------------------===// // BoolAttr bool BoolAttr::getValue() const { - auto *storage = reinterpret_cast(impl); - return storage->getValue().getBoolValue(); + auto *storage = reinterpret_cast(impl); + return storage->value.getBoolValue(); } bool BoolAttr::classof(Attribute attr) { @@ -987,11 +928,11 @@ auto DenseElementsAttr::getComplexFloatValues() const /// Return the raw storage data held by this attribute. ArrayRef DenseElementsAttr::getRawData() const { - return static_cast(impl)->data; + return static_cast(impl)->data; } ArrayRef DenseElementsAttr::getRawStringData() const { - return static_cast(impl)->data; + return static_cast(impl)->data; } /// Return a new DenseElementsAttr that has the same data as the current @@ -1021,15 +962,6 @@ DenseElementsAttr DenseElementsAttr::mapValues( return cast().mapValues(newElementType, mapping); } -//===----------------------------------------------------------------------===// -// DenseStringElementsAttr -//===----------------------------------------------------------------------===// - -DenseStringElementsAttr -DenseStringElementsAttr::get(ShapedType type, ArrayRef values) { - return Base::get(type.getContext(), type, values, (values.size() == 1)); -} - //===----------------------------------------------------------------------===// // DenseIntOrFPElementsAttr //===----------------------------------------------------------------------===// @@ -1254,15 +1186,6 @@ bool DenseIntElementsAttr::classof(Attribute attr) { // OpaqueElementsAttr //===----------------------------------------------------------------------===// -OpaqueElementsAttr OpaqueElementsAttr::get(Dialect *dialect, ShapedType type, - StringRef bytes) { - assert(TensorType::isValidElementType(type.getElementType()) && - "Input element type should be a valid tensor element type"); - return Base::get(type.getContext(), type, dialect, bytes); -} - -StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } - /// Return the value at the given index. If index does not refer to a valid /// element, then a null attribute is returned. Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { @@ -1270,43 +1193,30 @@ Attribute OpaqueElementsAttr::getValue(ArrayRef index) const { return Attribute(); } -Dialect *OpaqueElementsAttr::getDialect() const { return getImpl()->dialect; } - bool OpaqueElementsAttr::decode(ElementsAttr &result) { - auto *d = getDialect(); - if (!d) + Dialect *dialect = getDialect().getDialect(); + if (!dialect) return true; auto *interface = - d->getRegisteredInterface(); + dialect->getRegisteredInterface(); if (!interface) return true; return failed(interface->decode(*this, result)); } +LogicalResult +OpaqueElementsAttr::verify(function_ref emitError, + Identifier dialect, StringRef value, + ShapedType type) { + if (!Dialect::isValidNamespace(dialect.strref())) + return emitError() << "invalid dialect namespace '" << dialect << "'"; + return success(); +} + //===----------------------------------------------------------------------===// // SparseElementsAttr //===----------------------------------------------------------------------===// -SparseElementsAttr SparseElementsAttr::get(ShapedType type, - DenseElementsAttr indices, - DenseElementsAttr values) { - assert(indices.getType().getElementType().isInteger(64) && - "expected sparse indices to be 64-bit integer values"); - assert((type.isa()) && - "type must be ranked tensor or vector"); - assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), type, - indices.cast(), values); -} - -DenseIntElementsAttr SparseElementsAttr::getIndices() const { - return getImpl()->indices; -} - -DenseElementsAttr SparseElementsAttr::getValues() const { - return getImpl()->values; -} - /// Return the value of the element at the given index. Attribute SparseElementsAttr::getValue(ArrayRef index) const { assert(isValidIndex(index) && "expected valid multi-dimensional index"); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 464c1a7c842f..ddc88047b7ee 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -390,17 +390,13 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry) //// Attributes. //// Note: These must be registered after the types as they may generate one //// of the above types internally. - /// Bool Attributes. - impl->falseAttr = AttributeUniquer::get( - this, impl->int1Ty, APInt(/*numBits=*/1, false)) - .cast(); - impl->trueAttr = AttributeUniquer::get( - this, impl->int1Ty, APInt(/*numBits=*/1, true)) - .cast(); - /// Unit Attribute. - impl->unitAttr = AttributeUniquer::get(this); /// Unknown Location Attribute. impl->unknownLocAttr = AttributeUniquer::get(this); + /// Bool Attributes. + impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, false); + impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(impl->int1Ty, true); + /// Unit Attribute. + impl->unitAttr = AttributeUniquer::get(this); /// The empty dictionary attribute. impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(this); diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp index 98f74174e5a3..f71f2a21669a 100644 --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -862,16 +862,7 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) { if (getToken().isNot(Token::string)) return (emitError("expected dialect namespace"), nullptr); - auto name = getToken().getStringValue(); - // Lazy load a dialect in the context if there is a possible namespace. - Dialect *dialect = builder.getContext()->getOrLoadDialect(name); - - // TODO: Allow for having an unknown dialect on an opaque - // attribute. Otherwise, it can't be roundtripped without having the dialect - // registered. - if (!dialect) - return (emitError("no registered dialect with namespace '" + name + "'"), - nullptr); + std::string name = getToken().getStringValue(); consumeToken(Token::string); if (parseToken(Token::comma, "expected ','")) @@ -888,7 +879,7 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) { std::string data; if (parseElementAttrHexValues(*this, hexTok, data)) return nullptr; - return OpaqueElementsAttr::get(dialect, type, data); + return OpaqueElementsAttr::get(builder.getIdentifier(name), type, data); } /// Shaped type for elements attribute. diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp index eea03015d329..1e4f5e4becdd 100644 --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -45,10 +45,6 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) { } builders.emplace_back(builder); } - } else if (skipDefaultBuilders()) { - PrintFatalError( - def->getLoc(), - "default builders are skipped and no custom builders provided"); } } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c index b72a6e6cf2fd..beb73102615e 100644 --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -449,7 +449,7 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL); fprintf(stderr, "\n"); // clang-format off - // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown) + // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"_", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown) // clang-format on mlirOpPrintingFlagsDestroy(flags); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 419c98626521..4c4df915167a 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -766,21 +766,14 @@ func @elementsattr_malformed_opaque() -> () { func @elementsattr_malformed_opaque1() -> () { ^bb0: - "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}} + "foo"(){bar = opaque<"_", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}} } // ----- func @elementsattr_malformed_opaque2() -> () { ^bb0: - "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}} -} - -// ----- - -func @elementsattr_malformed_opaque3() -> () { -^bb0: - "foo"(){bar = opaque<"t", "0xabc"> : tensor<1xi8>} : () -> () // expected-error {{no registered dialect with namespace 't'}} + "foo"(){bar = opaque<"_", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{expected string containing hex digits starting with `0x`}} } // ----- @@ -881,7 +874,7 @@ func @type_alias_unknown(!unknown_alias) -> () { // expected-error {{undefined s func @complex_loops() { affine.for %i1 = 1 to 100 { // expected-error @+1 {{expected '"' in string literal}} - "opaqueIntTensor"(){bar = opaque<"", "0x686]> : tensor<2x1x4xi32>} : () -> () + "opaqueIntTensor"(){bar = opaque<"_", "0x686]> : tensor<2x1x4xi32>} : () -> () // ----- diff --git a/mlir/test/IR/pretty-attributes.mlir b/mlir/test/IR/pretty-attributes.mlir index d4ac8e773935..280e32672ea5 100644 --- a/mlir/test/IR/pretty-attributes.mlir +++ b/mlir/test/IR/pretty-attributes.mlir @@ -5,17 +5,17 @@ // tensor which passes don't look at directly, this isn't an issue. // RUN: mlir-opt %s -mlir-elide-elementsattrs-if-larger=2 | mlir-opt -// CHECK: opaque<"", "0xDEADBEEF"> : tensor<3xi32> +// CHECK: opaque<"_", "0xDEADBEEF"> : tensor<3xi32> "test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> () // CHECK: dense<[1, 2]> : tensor<2xi32> "test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> () -// CHECK: opaque<"", "0xDEADBEEF"> : vector<1x1x1xf16> +// CHECK: opaque<"_", "0xDEADBEEF"> : vector<1x1x1xf16> "test.sparse_attr"() {foo.sparse_attr = sparse<[[1, 2, 3]], -2.0> : vector<1x1x1xf16>} : () -> () -// CHECK: opaque<"", "0xDEADBEEF"> : tensor<100xf32> -"test.opaque_attr"() {foo.opaque_attr = opaque<"", "0xEBFE"> : tensor<100xf32> } : () -> () +// CHECK: opaque<"_", "0xDEADBEEF"> : tensor<100xf32> +"test.opaque_attr"() {foo.opaque_attr = opaque<"_", "0xEBFE"> : tensor<100xf32> } : () -> () // CHECK: dense<1> : tensor<3xi32> "test.dense_splat"() {foo.dense_attr = dense<1> : tensor<3xi32>} : () -> () diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index 252b9175b05d..fc95fba3c91c 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -95,7 +95,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> { // DEF: return new (allocator.allocate()) // DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner); -// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); } +// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType().cast<::mlir::Type>(); } } def C_IndexAttr : TestAttr<"Index"> { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 636d4f8b51ef..a951df92fe18 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -776,15 +776,19 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) { // Generate accessor definitions only if we also generate the storage class. // Otherwise, let the user define the exact accessor definition. if (def.genAccessors() && def.genStorageClass()) { - for (const AttrOrTypeParameter ¶meter : parameters) { - StringRef paramStorageName = isa(parameter) - ? "getType()" - : parameter.getName(); + for (const AttrOrTypeParameter ¶m : parameters) { + SmallString<32> paramStorageName; + if (isa(param)) { + Twine("getType().cast<" + param.getCppType() + ">()") + .toVector(paramStorageName); + } else { + paramStorageName = param.getName(); + } - SmallString<16> name = parameter.getName(); + SmallString<16> name = param.getName(); name[0] = llvm::toUpper(name[0]); os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n", - parameter.getCppType(), name, paramStorageName, + param.getCppType(), name, paramStorageName, def.getCppClassName()); } }