Introduce integer set attribute

- add IntegerSetAttr to Attributes; add parsing and other support for it
  (builder, etc.).

PiperOrigin-RevId: 218804579
This commit is contained in:
Uday Bondhugula 2018-10-25 22:13:03 -07:00 committed by jpienaar
parent 967d934180
commit ea65c695b9
9 changed files with 152 additions and 35 deletions

View File

@ -19,11 +19,13 @@
#define MLIR_IR_ATTRIBUTES_H #define MLIR_IR_ATTRIBUTES_H
#include "mlir/IR/AffineMap.h" #include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/APFloat.h" #include "llvm/ADT/APFloat.h"
#include "llvm/Support/TrailingObjects.h" #include "llvm/Support/TrailingObjects.h"
namespace mlir { namespace mlir {
class Function; class Function;
class FunctionType; class FunctionType;
class MLIRContext; class MLIRContext;
@ -39,6 +41,7 @@ struct FloatAttributeStorage;
struct StringAttributeStorage; struct StringAttributeStorage;
struct ArrayAttributeStorage; struct ArrayAttributeStorage;
struct AffineMapAttributeStorage; struct AffineMapAttributeStorage;
struct IntegerSetAttributeStorage;
struct TypeAttributeStorage; struct TypeAttributeStorage;
struct FunctionAttributeStorage; struct FunctionAttributeStorage;
struct ElementsAttributeStorage; struct ElementsAttributeStorage;
@ -66,6 +69,7 @@ public:
Type, Type,
Array, Array,
AffineMap, AffineMap,
IntegerSet,
Function, Function,
SplatElements, SplatElements,
@ -210,6 +214,20 @@ public:
static bool kindof(Kind kind) { return kind == Kind::AffineMap; } static bool kindof(Kind kind) { return kind == Kind::AffineMap; }
}; };
class IntegerSetAttr : public Attribute {
public:
typedef detail::IntegerSetAttributeStorage ImplType;
IntegerSetAttr() = default;
/* implicit */ IntegerSetAttr(Attribute::ImplType *ptr);
static IntegerSetAttr get(IntegerSet value);
IntegerSet getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::IntegerSet; }
};
class TypeAttr : public Attribute { class TypeAttr : public Attribute {
public: public:
typedef detail::TypeAttributeStorage ImplType; typedef detail::TypeAttributeStorage ImplType;

View File

@ -101,6 +101,7 @@ public:
StringAttr getStringAttr(StringRef bytes); StringAttr getStringAttr(StringRef bytes);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value); ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
AffineMapAttr getAffineMapAttr(AffineMap map); AffineMapAttr getAffineMapAttr(AffineMap map);
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type *type); TypeAttr getTypeAttr(Type *type);
FunctionAttr getFunctionAttr(const Function *value); FunctionAttr getFunctionAttr(const Function *value);
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt); ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);

View File

@ -153,6 +153,8 @@ void ModuleState::visitType(const Type *type) {
void ModuleState::visitAttribute(Attribute attr) { void ModuleState::visitAttribute(Attribute attr) {
if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) { if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) {
recordAffineMapReference(mapAttr.getValue()); recordAffineMapReference(mapAttr.getValue());
} else if (auto setAttr = attr.dyn_cast<IntegerSetAttr>()) {
recordIntegerSetReference(setAttr.getValue());
} else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) { } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
for (auto elt : arrayAttr.getValue()) { for (auto elt : arrayAttr.getValue()) {
visitAttribute(elt); visitAttribute(elt);
@ -429,6 +431,9 @@ void ModulePrinter::printAttribute(Attribute attr) {
case Attribute::Kind::AffineMap: case Attribute::Kind::AffineMap:
printAffineMapReference(attr.cast<AffineMapAttr>().getValue()); printAffineMapReference(attr.cast<AffineMapAttr>().getValue());
break; break;
case Attribute::Kind::IntegerSet:
printIntegerSetReference(attr.cast<IntegerSetAttr>().getValue());
break;
case Attribute::Kind::Type: case Attribute::Kind::Type:
printType(attr.cast<TypeAttr>().getValue()); printType(attr.cast<TypeAttr>().getValue());
break; break;

View File

@ -76,6 +76,11 @@ struct AffineMapAttributeStorage : public AttributeStorage {
AffineMap value; AffineMap value;
}; };
// An attribute representing a reference to an integer set.
struct IntegerSetAttributeStorage : public AttributeStorage {
IntegerSet value;
};
/// An attribute representing a reference to a type. /// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage { struct TypeAttributeStorage : public AttributeStorage {
Type *value; Type *value;

View File

@ -17,7 +17,9 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "AttributeDetail.h" #include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
using namespace mlir; using namespace mlir;
@ -65,6 +67,12 @@ AffineMap AffineMapAttr::getValue() const {
return static_cast<ImplType *>(attr)->value; return static_cast<ImplType *>(attr)->value;
} }
IntegerSetAttr::IntegerSetAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
IntegerSet IntegerSetAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
Type *TypeAttr::getValue() const { Type *TypeAttr::getValue() const {

View File

@ -140,6 +140,10 @@ AffineMapAttr Builder::getAffineMapAttr(AffineMap map) {
return AffineMapAttr::get(map); return AffineMapAttr::get(map);
} }
IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
return IntegerSetAttr::get(set);
}
TypeAttr Builder::getTypeAttr(Type *type) { TypeAttr Builder::getTypeAttr(Type *type) {
return TypeAttr::get(type, context); return TypeAttr::get(type, context);
} }

View File

@ -354,6 +354,7 @@ public:
using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>; using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs; ArrayAttrSet arrayAttrs;
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs; DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
DenseMap<Type *, TypeAttributeStorage *> typeAttrs; DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
using AttributeListSet = using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>; DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
@ -870,6 +871,19 @@ AffineMapAttr AffineMapAttr::get(AffineMap value) {
return result; return result;
} }
IntegerSetAttr IntegerSetAttr::get(IntegerSet value) {
auto *context = value.getConstraint(0).getContext();
auto &result = context->getImpl().integerSetAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<IntegerSetAttributeStorage>();
new (result) IntegerSetAttributeStorage{{Attribute::Kind::IntegerSet,
/*isOrContainsFunction=*/false},
value};
return result;
}
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) { TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
auto *&result = context->getImpl().typeAttrs[type]; auto *&result = context->getImpl().typeAttrs[type];
if (result) if (result)

View File

@ -201,6 +201,7 @@ public:
// Polyhedral structures. // Polyhedral structures.
void parseAffineStructureInline(AffineMap *map, IntegerSet *set); void parseAffineStructureInline(AffineMap *map, IntegerSet *set);
void parseAffineStructureReference(AffineMap *map, IntegerSet *set);
AffineMap parseAffineMapInline(); AffineMap parseAffineMapInline();
AffineMap parseAffineMapReference(); AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline(); IntegerSet parseIntegerSetInline();
@ -873,10 +874,16 @@ Attribute Parser::parseAttribute() {
} }
case Token::hash_identifier: case Token::hash_identifier:
case Token::l_paren: { case Token::l_paren: {
// Try to parse affine map reference. // Try to parse an affine map or an integer set reference.
if (auto affineMap = parseAffineMapReference()) AffineMap map;
return builder.getAffineMapAttr(affineMap); IntegerSet set;
return (emitError("expected constant attribute value"), nullptr); parseAffineStructureReference(&map, &set);
if (map)
return builder.getAffineMapAttr(map);
if (set)
return builder.getIntegerSetAttr(set);
return (emitError("expected affine map or integer set attribute value"),
nullptr);
} }
case Token::at_identifier: { case Token::at_identifier: {
@ -1718,18 +1725,76 @@ AffineMap Parser::parseAffineMapInline() {
return map; return map;
} }
AffineMap Parser::parseAffineMapReference() { /// Parse either an affine map reference or integer set reference.
if (getToken().is(Token::hash_identifier)) { ///
// Parse affine map identifier and verify that it exists. /// affine-structure ::= affine-structure-id | affine-structure-inline
StringRef affineMapId = getTokenSpelling().drop_front(); /// affine-structure-id ::= `#` suffix-id
if (getState().affineMapDefinitions.count(affineMapId) == 0) ///
return (emitError("undefined affine map id '" + affineMapId + "'"), /// affine-structure ::= affine-map | integer-set
AffineMap::Null()); ///
consumeToken(Token::hash_identifier); void Parser::parseAffineStructureReference(AffineMap *map, IntegerSet *set) {
return getState().affineMapDefinitions[affineMapId]; assert((map || set) && "both map and set are non-null");
if (getToken().isNot(Token::hash_identifier)) {
// Try to parse inline affine map or integer set.
return parseAffineStructureInline(map, set);
} }
// Try to parse inline affine map.
return parseAffineMapInline(); // Parse affine map / integer set identifier and verify that it exists.
// Note that an id can't be in both affineMapDefinitions and
// integerSetDefinitions since they use the same sigil '#'.
StringRef affineStructId = getTokenSpelling().drop_front();
if (getState().affineMapDefinitions.count(affineStructId) > 0) {
consumeToken(Token::hash_identifier);
if (map)
*map = getState().affineMapDefinitions[affineStructId];
if (set)
*set = IntegerSet::Null();
return;
}
if (getState().integerSetDefinitions.count(affineStructId) > 0) {
consumeToken(Token::hash_identifier);
if (set)
*set = getState().integerSetDefinitions[affineStructId];
if (map)
*map = AffineMap::Null();
return;
}
// The id isn't among any of the recorded definitions.
// Emit the right message depending on what the caller expected.
if (map && !set)
emitError("undefined affine map id '" + affineStructId + "'");
else if (set && !map)
emitError("undefined integer set id '" + affineStructId + "'");
else if (set && map)
emitError("undefined affine map or integer set id '" + affineStructId +
"'");
if (map)
*map = AffineMap::Null();
if (set)
*set = IntegerSet::Null();
}
/// Parse a reference to an integer set.
/// affine-map ::= affine-map-id | affine-map-inline
/// affine-map-id ::= `#` suffix-id
///
AffineMap Parser::parseAffineMapReference() {
AffineMap map;
parseAffineStructureReference(&map, nullptr);
return map;
}
/// Parse a reference to an integer set.
/// integer-set ::= integer-set-id | integer-set-inline
/// integer-set-id ::= `#` suffix-id
///
IntegerSet Parser::parseIntegerSetReference() {
IntegerSet set;
parseAffineStructureReference(nullptr, &set);
return set;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2993,24 +3058,6 @@ IntegerSet Parser::parseIntegerSetInline() {
return set; return set;
} }
/// Parse a reference to an integer set.
/// integer-set ::= integer-set-id | integer-set-inline
/// integer-set-id ::= `#` suffix-id
///
IntegerSet Parser::parseIntegerSetReference() {
if (getToken().is(Token::hash_identifier)) {
// Parse integer set identifier and verify that it exists.
StringRef integerSetId = getTokenSpelling().drop_front(1);
if (getState().integerSetDefinitions.count(integerSetId) == 0)
return (emitError("undefined integer set id '" + integerSetId + "'"),
IntegerSet());
consumeToken(Token::hash_identifier);
return getState().integerSetDefinitions[integerSetId];
}
// Try to parse an inline integer set definition.
return parseIntegerSetInline();
}
/// If statement. /// If statement.
/// ///
/// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}` /// ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`

View File

@ -33,10 +33,16 @@
// CHECK: #map{{[0-9]+}} = (d0)[s0] -> (d0 + s0, d0 - s0) // CHECK: #map{{[0-9]+}} = (d0)[s0] -> (d0 + s0, d0 - s0)
#bound_map2 = (i)[s] -> (i + s, i - s) #bound_map2 = (i)[s] -> (i + s, i - s)
// CHECK-DAG: #set0 = (d0)[s0, s1] : (d0 >= 0, d0 * -1 + s0 >= 0, s0 - 5 == 0, d0 * -1 + s1 + 1 >= 0) // CHECK-DAG: #set{{[0-9]+}} = (d0)[s0, s1] : (d0 >= 0, d0 * -1 + s0 >= 0, s0 - 5 == 0, d0 * -1 + s1 + 1 >= 0)
#set0 = (i)[N, M] : (i >= 0, -i + N >= 0, N - 5 == 0, -i + M + 1 >= 0) #set0 = (i)[N, M] : (i >= 0, -i + N >= 0, N - 5 == 0, -i + M + 1 >= 0)
// CHECK-DAG: #set1 = (d0)[s0] : (d0 - 2 >= 0, d0 * -1 + 4 >= 0) // CHECK-DAG: #set{{[0-9]+}} = (d0, d1)[s0] : (d0 >= 0, d1 >= 0)
#set1 = (d0, d1)[s0] : (d0 >= 0, d1 >= 0)
// CHECK-DAG: #set{{[0-9]+}} = (d0) : (d0 - 1 == 0)
#set2 = (d0) : (d0 - 1 == 0)
// CHECK-DAG: #set{{[0-9]+}} = (d0)[s0] : (d0 - 2 >= 0, d0 * -1 + 4 >= 0)
// CHECK: extfunc @foo(i32, i64) -> f32 // CHECK: extfunc @foo(i32, i64) -> f32
extfunc @foo(i32, i64) -> f32 extfunc @foo(i32, i64) -> f32
@ -291,6 +297,15 @@ bb42: // CHECK: bb0:
// CHECK: "foo"() {map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]} // CHECK: "foo"() {map12: [#map{{[0-9]+}}, #map{{[0-9]+}}]}
"foo"() {map12: [#map1, #map2]} : () -> () "foo"() {map12: [#map1, #map2]} : () -> ()
// CHECK: "foo"() {set1: #set{{[0-9]+}}}
"foo"() {set1: #set1} : () -> ()
// CHECK: "foo"() {set2: #set{{[0-9]+}}}
"foo"() {set2: (d0, d1, d2) : (d0 >= 0, d1 >= 0, d2 - d1 == 0)} : () -> ()
// CHECK: "foo"() {set12: [#set{{[0-9]+}}, #set{{[0-9]+}}]}
"foo"() {set12: [#set1, #set2]} : () -> ()
// CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> () // CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
"foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> () "foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()