[mlir:PDL] Allow non-bound pdl.attribute/pdl.type operations that create constants

This allows for passing in these attributes/types to constraints/rewrites as arguments.

Differential Revision: https://reviews.llvm.org/D114817
This commit is contained in:
River Riddle 2021-12-10 19:36:21 +00:00
parent 06c3b9c7be
commit 233e9476d8
7 changed files with 172 additions and 16 deletions

View File

@ -512,6 +512,13 @@ def PDLInterp_CreateTypesOp : PDLInterp_Op<"create_types", [NoSideEffect]> {
let arguments = (ins TypeArrayAttr:$value);
let results = (outs PDL_RangeOf<PDL_Type>:$result);
let assemblyFormat = "$value attr-dict";
let builders = [
OpBuilder<(ins "ArrayAttr":$type), [{
build($_builder, $_state,
pdl::RangeType::get($_builder.getType<pdl::TypeType>()), type);
}]>
];
}
//===----------------------------------------------------------------------===//

View File

@ -237,10 +237,12 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
return val;
// Get the value for the parent position.
Value parentVal = getValueAt(currentBlock, pos->getParent());
Value parentVal;
if (Position *parent = pos->getParent())
parentVal = getValueAt(currentBlock, pos->getParent());
// TODO: Use a location from the position.
Location loc = parentVal.getLoc();
Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
builder.setInsertionPointToEnd(currentBlock);
Value value;
switch (pos->getKind()) {
@ -331,6 +333,22 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
parentVal, resPos->getResultGroupNumber());
break;
}
case Predicates::AttributeLiteralPos: {
auto *attrPos = cast<AttributeLiteralPosition>(pos);
value =
builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
break;
}
case Predicates::TypeLiteralPos: {
auto *typePos = cast<TypeLiteralPosition>(pos);
Attribute rawTypeAttr = typePos->getValue();
if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
else
value = builder.create<pdl_interp::CreateTypesOp>(
loc, rawTypeAttr.cast<ArrayAttr>());
break;
}
default:
llvm_unreachable("Generating unknown Position getter");
break;
@ -353,7 +371,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
args = {getValueAt(currentBlock, equalToQuestion->getValue())};
} else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
for (Position *position : std::get<1>(cstQuestion->getValue()))
for (Position *position : cstQuestion->getArgs())
args.push_back(getValueAt(currentBlock, position));
}
@ -413,10 +431,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
break;
}
case Predicates::ConstraintQuestion: {
auto value = cast<ConstraintQuestion>(question)->getValue();
auto *cstQuestion = cast<ConstraintQuestion>(question);
builder.create<pdl_interp::ApplyConstraintOp>(
loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(),
success, failure);
loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
failure);
break;
}
default:

View File

@ -21,7 +21,7 @@ Position::~Position() {}
unsigned Position::getOperationDepth() const {
if (const auto *operationPos = dyn_cast<OperationPosition>(this))
return operationPos->getDepth();
return parent->getOperationDepth();
return parent ? parent->getOperationDepth() : 0;
}
//===----------------------------------------------------------------------===//

View File

@ -50,6 +50,8 @@ enum Kind : unsigned {
ResultPos,
ResultGroupPos,
TypePos,
AttributeLiteralPos,
TypeLiteralPos,
// Questions, ordered by dependency and decreasing priority.
IsNotNullQuestion,
@ -173,6 +175,16 @@ struct AttributePosition
StringAttr getName() const { return key.second; }
};
//===----------------------------------------------------------------------===//
// AttributeLiteralPosition
/// A position describing a literal attribute.
struct AttributeLiteralPosition
: public PredicateBase<AttributeLiteralPosition, Position, Attribute,
Predicates::AttributeLiteralPos> {
using PredicateBase::PredicateBase;
};
//===----------------------------------------------------------------------===//
// OperandPosition
@ -317,6 +329,17 @@ struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
}
};
//===----------------------------------------------------------------------===//
// TypeLiteralPosition
/// A position describing a literal type or type range. The value is stored as
/// either a TypeAttr, or an ArrayAttr of TypeAttr.
struct TypeLiteralPosition
: public PredicateBase<TypeLiteralPosition, Position, Attribute,
Predicates::TypeLiteralPos> {
using PredicateBase::PredicateBase;
};
//===----------------------------------------------------------------------===//
// Qualifiers
//===----------------------------------------------------------------------===//
@ -404,6 +427,17 @@ struct ConstraintQuestion
Predicates::ConstraintQuestion> {
using Base::Base;
/// Return the name of the constraint.
StringRef getName() const { return std::get<0>(key); }
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
/// Return the constant parameters of the constraint.
ArrayAttr getParams() const {
return std::get<2>(key).dyn_cast_or_null<ArrayAttr>();
}
/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
@ -461,12 +495,14 @@ public:
PredicateUniquer() {
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
registerParametricStorageType<OperationPosition>();
registerParametricStorageType<ResultPosition>();
registerParametricStorageType<ResultGroupPosition>();
registerParametricStorageType<TypePosition>();
registerParametricStorageType<TypeLiteralPosition>();
// Register the types of Questions with the uniquer.
registerParametricStorageType<AttributeAnswer>();
@ -527,6 +563,11 @@ public:
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
}
/// Returns an attribute position for the given attribute.
Position *getAttributeLiteral(Attribute attr) {
return AttributeLiteralPosition::get(uniquer, attr);
}
/// Returns an operand position for an operand of the given operation.
Position *getOperand(OperationPosition *p, unsigned operand) {
return OperandPosition::get(uniquer, p, operand);
@ -558,6 +599,12 @@ public:
/// Returns a type position for the given entity.
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
/// Returns a type position for the given type value. The value is stored
/// as either a TypeAttr, or an ArrayAttr of TypeAttr.
Position *getTypeLiteral(Attribute attr) {
return TypeLiteralPosition::get(uniquer, attr);
}
//===--------------------------------------------------------------------===//
// Qualifiers
//===--------------------------------------------------------------------===//

View File

@ -243,8 +243,18 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
}
/// Collect all of the predicates related to constraints within the given
/// pattern operation.
static void getAttributePredicates(pdl::AttributeOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
Position *&attrPos = inputs[op];
if (attrPos)
return;
Attribute value = op.valueAttr();
assert(value && "expected non-tree `pdl.attribute` to contain a value");
attrPos = builder.getAttributeLiteral(value);
}
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
std::vector<PositionalPredicate> &predList,
PredicateBuilder &builder,
@ -296,6 +306,19 @@ static void getResultPredicates(pdl::ResultsOp op,
predList.emplace_back(resultPos, builder.getIsNotNull());
}
static void getTypePredicates(Value typeValue,
function_ref<Attribute()> typeAttrFn,
PredicateBuilder &builder,
DenseMap<Value, Position *> &inputs) {
Position *&typePos = inputs[typeValue];
if (typePos)
return;
Attribute typeAttr = typeAttrFn();
assert(typeAttr &&
"expected non-tree `pdl.type`/`pdl.types` to contain a value");
typePos = builder.getTypeLiteral(typeAttr);
}
/// Collect all of the predicates that cannot be determined via walking the
/// tree.
static void getNonTreePredicates(pdl::PatternOp pattern,
@ -304,11 +327,22 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
DenseMap<Value, Position *> &inputs) {
for (Operation &op : pattern.body().getOps()) {
TypeSwitch<Operation *>(&op)
.Case([&](pdl::AttributeOp attrOp) {
getAttributePredicates(attrOp, predList, builder, inputs);
})
.Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
getConstraintPredicates(constraintOp, predList, builder, inputs);
})
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
getResultPredicates(resultOp, predList, builder, inputs);
})
.Case([&](pdl::TypeOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
})
.Case([&](pdl::TypesOp typeOp) {
getTypePredicates(
typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
});
}
}

View File

@ -114,12 +114,15 @@ static LogicalResult verify(AttributeOp op) {
Value attrType = op.type();
Optional<Attribute> attrValue = op.value();
if (!attrValue && isa<RewriteOp>(op->getParentOp()))
return op.emitOpError("expected constant value when specified within a "
"`pdl.rewrite`");
if (attrValue && attrType)
if (!attrValue) {
if (isa<RewriteOp>(op->getParentOp()))
return op.emitOpError("expected constant value when specified within a "
"`pdl.rewrite`");
return verifyHasBindingUse(op);
}
if (attrType)
return op.emitOpError("expected only one of [`type`, `value`] to be set");
return verifyHasBindingUse(op);
return success();
}
//===----------------------------------------------------------------------===//
@ -431,13 +434,21 @@ static LogicalResult verify(RewriteOp op) {
// pdl::TypeOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); }
static LogicalResult verify(TypeOp op) {
if (!op.typeAttr())
return verifyHasBindingUse(op);
return success();
}
//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); }
static LogicalResult verify(TypesOp op) {
if (!op.typesAttr())
return verifyHasBindingUse(op);
return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions

View File

@ -573,3 +573,42 @@ module @variadic_results_at {
pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
}
}
// -----
// CHECK-LABEL: module @attribute_literal
module @attribute_literal {
// CHECK: func @matcher(%{{.*}}: !pdl.operation)
// CHECK: %[[ATTR:.*]] = pdl_interp.create_attribute 10 : i64
// CHECK: pdl_interp.apply_constraint "constraint"(%[[ATTR]] : !pdl.attribute)
// Check the correct lowering of an attribute that hasn't been bound.
pdl.pattern : benefit(1) {
%attr = pdl.attribute 10
pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute)
%root = pdl.operation
pdl.rewrite %root with "rewriter"
}
}
// -----
// CHECK-LABEL: module @type_literal
module @type_literal {
// CHECK: func @matcher(%{{.*}}: !pdl.operation)
// CHECK: %[[TYPE:.*]] = pdl_interp.create_type i32
// CHECK: %[[TYPES:.*]] = pdl_interp.create_types [i32, i64]
// CHECK: pdl_interp.apply_constraint "constraint"(%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>)
// Check the correct lowering of a type that hasn't been bound.
pdl.pattern : benefit(1) {
%type = pdl.type : i32
%types = pdl.types : [i32, i64]
pdl.apply_native_constraint "constraint"(%type, %types: !pdl.type, !pdl.range<type>)
%root = pdl.operation
pdl.rewrite %root with "rewriter"
}
}