diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 961bbab94859..ebe05c328958 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -19,11 +19,13 @@ #define MLIR_IR_ATTRIBUTES_H #include "mlir/IR/AffineMap.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APFloat.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { + class Function; class FunctionType; class MLIRContext; @@ -39,6 +41,7 @@ struct FloatAttributeStorage; struct StringAttributeStorage; struct ArrayAttributeStorage; struct AffineMapAttributeStorage; +struct IntegerSetAttributeStorage; struct TypeAttributeStorage; struct FunctionAttributeStorage; struct ElementsAttributeStorage; @@ -66,6 +69,7 @@ public: Type, Array, AffineMap, + IntegerSet, Function, SplatElements, @@ -210,6 +214,20 @@ public: static bool kindof(Kind kind) { return kind == Kind::AffineMap; } }; +class IntegerSetAttr : public Attribute { +public: + typedef detail::IntegerSetAttributeStorage ImplType; + IntegerSetAttr() = default; + /* implicit */ IntegerSetAttr(Attribute::ImplType *ptr); + + static IntegerSetAttr get(IntegerSet value); + + IntegerSet getValue() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(Kind kind) { return kind == Kind::IntegerSet; } +}; + class TypeAttr : public Attribute { public: typedef detail::TypeAttributeStorage ImplType; diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index bce309d9c9a6..2e48008c651a 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -101,6 +101,7 @@ public: StringAttr getStringAttr(StringRef bytes); ArrayAttr getArrayAttr(ArrayRef value); AffineMapAttr getAffineMapAttr(AffineMap map); + IntegerSetAttr getIntegerSetAttr(IntegerSet set); TypeAttr getTypeAttr(Type *type); FunctionAttr getFunctionAttr(const Function *value); ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index b3f1c494cb7b..454a28a65585 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -153,6 +153,8 @@ void ModuleState::visitType(const Type *type) { void ModuleState::visitAttribute(Attribute attr) { if (auto mapAttr = attr.dyn_cast()) { recordAffineMapReference(mapAttr.getValue()); + } else if (auto setAttr = attr.dyn_cast()) { + recordIntegerSetReference(setAttr.getValue()); } else if (auto arrayAttr = attr.dyn_cast()) { for (auto elt : arrayAttr.getValue()) { visitAttribute(elt); @@ -429,6 +431,9 @@ void ModulePrinter::printAttribute(Attribute attr) { case Attribute::Kind::AffineMap: printAffineMapReference(attr.cast().getValue()); break; + case Attribute::Kind::IntegerSet: + printIntegerSetReference(attr.cast().getValue()); + break; case Attribute::Kind::Type: printType(attr.cast().getValue()); break; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 82acec70a4f5..e0e9663b837d 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -76,6 +76,11 @@ struct AffineMapAttributeStorage : public AttributeStorage { AffineMap value; }; +// An attribute representing a reference to an integer set. +struct IntegerSetAttributeStorage : public AttributeStorage { + IntegerSet value; +}; + /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { Type *value; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 8c1a5d3bb3ea..34312b84a0bc 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -17,7 +17,9 @@ #include "mlir/IR/Attributes.h" #include "AttributeDetail.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Function.h" +#include "mlir/IR/IntegerSet.h" #include "mlir/IR/Types.h" using namespace mlir; @@ -65,6 +67,12 @@ AffineMap AffineMapAttr::getValue() const { return static_cast(attr)->value; } +IntegerSetAttr::IntegerSetAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} + +IntegerSet IntegerSetAttr::getValue() const { + return static_cast(attr)->value; +} + TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} Type *TypeAttr::getValue() const { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 5471cc50c163..22d749a6c8c7 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -140,6 +140,10 @@ AffineMapAttr Builder::getAffineMapAttr(AffineMap map) { return AffineMapAttr::get(map); } +IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) { + return IntegerSetAttr::get(set); +} + TypeAttr Builder::getTypeAttr(Type *type) { return TypeAttr::get(type, context); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index f6d236211d4a..0b660b2291ef 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -354,6 +354,7 @@ public: using ArrayAttrSet = DenseSet; ArrayAttrSet arrayAttrs; DenseMap affineMapAttrs; + DenseMap integerSetAttrs; DenseMap typeAttrs; using AttributeListSet = DenseSet; @@ -870,6 +871,19 @@ AffineMapAttr AffineMapAttr::get(AffineMap value) { return result; } +IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { + auto *context = value.getConstraint(0).getContext(); + auto &result = context->getImpl().integerSetAttrs[value]; + if (result) + return result; + + result = context->getImpl().allocator.Allocate(); + new (result) IntegerSetAttributeStorage{{Attribute::Kind::IntegerSet, + /*isOrContainsFunction=*/false}, + value}; + return result; +} + TypeAttr TypeAttr::get(Type *type, MLIRContext *context) { auto *&result = context->getImpl().typeAttrs[type]; if (result) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 643999a09c72..1950572b0ceb 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -201,6 +201,7 @@ public: // Polyhedral structures. void parseAffineStructureInline(AffineMap *map, IntegerSet *set); + void parseAffineStructureReference(AffineMap *map, IntegerSet *set); AffineMap parseAffineMapInline(); AffineMap parseAffineMapReference(); IntegerSet parseIntegerSetInline(); @@ -873,10 +874,16 @@ Attribute Parser::parseAttribute() { } case Token::hash_identifier: case Token::l_paren: { - // Try to parse affine map reference. - if (auto affineMap = parseAffineMapReference()) - return builder.getAffineMapAttr(affineMap); - return (emitError("expected constant attribute value"), nullptr); + // Try to parse an affine map or an integer set reference. + AffineMap map; + IntegerSet set; + parseAffineStructureReference(&map, &set); + if (map) + return builder.getAffineMapAttr(map); + if (set) + return builder.getIntegerSetAttr(set); + return (emitError("expected affine map or integer set attribute value"), + nullptr); } case Token::at_identifier: { @@ -1718,18 +1725,76 @@ AffineMap Parser::parseAffineMapInline() { return map; } -AffineMap Parser::parseAffineMapReference() { - if (getToken().is(Token::hash_identifier)) { - // Parse affine map identifier and verify that it exists. - StringRef affineMapId = getTokenSpelling().drop_front(); - if (getState().affineMapDefinitions.count(affineMapId) == 0) - return (emitError("undefined affine map id '" + affineMapId + "'"), - AffineMap::Null()); - consumeToken(Token::hash_identifier); - return getState().affineMapDefinitions[affineMapId]; +/// Parse either an affine map reference or integer set reference. +/// +/// affine-structure ::= affine-structure-id | affine-structure-inline +/// affine-structure-id ::= `#` suffix-id +/// +/// affine-structure ::= affine-map | integer-set +/// +void Parser::parseAffineStructureReference(AffineMap *map, IntegerSet *set) { + assert((map || set) && "both map and set are non-null"); + if (getToken().isNot(Token::hash_identifier)) { + // Try to parse inline affine map or integer set. + return parseAffineStructureInline(map, set); } - // Try to parse inline affine map. - return parseAffineMapInline(); + + // Parse affine map / integer set identifier and verify that it exists. + // Note that an id can't be in both affineMapDefinitions and + // integerSetDefinitions since they use the same sigil '#'. + StringRef affineStructId = getTokenSpelling().drop_front(); + if (getState().affineMapDefinitions.count(affineStructId) > 0) { + consumeToken(Token::hash_identifier); + if (map) + *map = getState().affineMapDefinitions[affineStructId]; + if (set) + *set = IntegerSet::Null(); + return; + } + + if (getState().integerSetDefinitions.count(affineStructId) > 0) { + consumeToken(Token::hash_identifier); + if (set) + *set = getState().integerSetDefinitions[affineStructId]; + if (map) + *map = AffineMap::Null(); + return; + } + + // The id isn't among any of the recorded definitions. + // Emit the right message depending on what the caller expected. + if (map && !set) + emitError("undefined affine map id '" + affineStructId + "'"); + else if (set && !map) + emitError("undefined integer set id '" + affineStructId + "'"); + else if (set && map) + emitError("undefined affine map or integer set id '" + affineStructId + + "'"); + + if (map) + *map = AffineMap::Null(); + if (set) + *set = IntegerSet::Null(); +} + +/// Parse a reference to an integer set. +/// affine-map ::= affine-map-id | affine-map-inline +/// affine-map-id ::= `#` suffix-id +/// +AffineMap Parser::parseAffineMapReference() { + AffineMap map; + parseAffineStructureReference(&map, nullptr); + return map; +} + +/// Parse a reference to an integer set. +/// integer-set ::= integer-set-id | integer-set-inline +/// integer-set-id ::= `#` suffix-id +/// +IntegerSet Parser::parseIntegerSetReference() { + IntegerSet set; + parseAffineStructureReference(nullptr, &set); + return set; } //===----------------------------------------------------------------------===// @@ -2993,24 +3058,6 @@ IntegerSet Parser::parseIntegerSetInline() { return set; } -/// Parse a reference to an integer set. -/// integer-set ::= integer-set-id | integer-set-inline -/// integer-set-id ::= `#` suffix-id -/// -IntegerSet Parser::parseIntegerSetReference() { - if (getToken().is(Token::hash_identifier)) { - // Parse integer set identifier and verify that it exists. - StringRef integerSetId = getTokenSpelling().drop_front(1); - if (getState().integerSetDefinitions.count(integerSetId) == 0) - return (emitError("undefined integer set id '" + integerSetId + "'"), - IntegerSet()); - consumeToken(Token::hash_identifier); - return getState().integerSetDefinitions[integerSetId]; - } - // Try to parse an inline integer set definition. - return parseIntegerSetInline(); -} - /// If statement. /// /// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}` diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 989b60733556..ea67c7efce53 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -33,10 +33,16 @@ // CHECK: #map{{[0-9]+}} = (d0)[s0] -> (d0 + s0, d0 - s0) #bound_map2 = (i)[s] -> (i + s, i - s) -// CHECK-DAG: #set0 = (d0)[s0, s1] : (d0 >= 0, d0 * -1 + s0 >= 0, s0 - 5 == 0, d0 * -1 + s1 + 1 >= 0) +// CHECK-DAG: #set{{[0-9]+}} = (d0)[s0, s1] : (d0 >= 0, d0 * -1 + s0 >= 0, s0 - 5 == 0, d0 * -1 + s1 + 1 >= 0) #set0 = (i)[N, M] : (i >= 0, -i + N >= 0, N - 5 == 0, -i + M + 1 >= 0) -// CHECK-DAG: #set1 = (d0)[s0] : (d0 - 2 >= 0, d0 * -1 + 4 >= 0) +// CHECK-DAG: #set{{[0-9]+}} = (d0, d1)[s0] : (d0 >= 0, d1 >= 0) +#set1 = (d0, d1)[s0] : (d0 >= 0, d1 >= 0) + +// CHECK-DAG: #set{{[0-9]+}} = (d0) : (d0 - 1 == 0) +#set2 = (d0) : (d0 - 1 == 0) + +// CHECK-DAG: #set{{[0-9]+}} = (d0)[s0] : (d0 - 2 >= 0, d0 * -1 + 4 >= 0) // CHECK: extfunc @foo(i32, i64) -> f32 extfunc @foo(i32, i64) -> f32 @@ -291,6 +297,15 @@ bb42: // CHECK: bb0: // CHECK: "foo"() {map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]} "foo"() {map12: [#map1, #map2]} : () -> () + // CHECK: "foo"() {set1: #set{{[0-9]+}}} + "foo"() {set1: #set1} : () -> () + + // CHECK: "foo"() {set2: #set{{[0-9]+}}} + "foo"() {set2: (d0, d1, d2) : (d0 >= 0, d1 >= 0, d2 - d1 == 0)} : () -> () + + // CHECK: "foo"() {set12: [#set{{[0-9]+}}, #set{{[0-9]+}}]} + "foo"() {set12: [#set1, #set2]} : () -> () + // CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> () "foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()