forked from OSchip/llvm-project
[mlir:PDL] Allow non-bound pdl.attribute/pdl.type operations that create constants
This allows for passing in these attributes/types to constraints/rewrites as arguments. Differential Revision: https://reviews.llvm.org/D114817
This commit is contained in:
parent
06c3b9c7be
commit
233e9476d8
|
@ -512,6 +512,13 @@ def PDLInterp_CreateTypesOp : PDLInterp_Op<"create_types", [NoSideEffect]> {
|
|||
let arguments = (ins TypeArrayAttr:$value);
|
||||
let results = (outs PDL_RangeOf<PDL_Type>:$result);
|
||||
let assemblyFormat = "$value attr-dict";
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "ArrayAttr":$type), [{
|
||||
build($_builder, $_state,
|
||||
pdl::RangeType::get($_builder.getType<pdl::TypeType>()), type);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -237,10 +237,12 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
|
|||
return val;
|
||||
|
||||
// Get the value for the parent position.
|
||||
Value parentVal = getValueAt(currentBlock, pos->getParent());
|
||||
Value parentVal;
|
||||
if (Position *parent = pos->getParent())
|
||||
parentVal = getValueAt(currentBlock, pos->getParent());
|
||||
|
||||
// TODO: Use a location from the position.
|
||||
Location loc = parentVal.getLoc();
|
||||
Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
|
||||
builder.setInsertionPointToEnd(currentBlock);
|
||||
Value value;
|
||||
switch (pos->getKind()) {
|
||||
|
@ -331,6 +333,22 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
|
|||
parentVal, resPos->getResultGroupNumber());
|
||||
break;
|
||||
}
|
||||
case Predicates::AttributeLiteralPos: {
|
||||
auto *attrPos = cast<AttributeLiteralPosition>(pos);
|
||||
value =
|
||||
builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
|
||||
break;
|
||||
}
|
||||
case Predicates::TypeLiteralPos: {
|
||||
auto *typePos = cast<TypeLiteralPosition>(pos);
|
||||
Attribute rawTypeAttr = typePos->getValue();
|
||||
if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
|
||||
value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
|
||||
else
|
||||
value = builder.create<pdl_interp::CreateTypesOp>(
|
||||
loc, rawTypeAttr.cast<ArrayAttr>());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("Generating unknown Position getter");
|
||||
break;
|
||||
|
@ -353,7 +371,7 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
|||
if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
|
||||
args = {getValueAt(currentBlock, equalToQuestion->getValue())};
|
||||
} else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
|
||||
for (Position *position : std::get<1>(cstQuestion->getValue()))
|
||||
for (Position *position : cstQuestion->getArgs())
|
||||
args.push_back(getValueAt(currentBlock, position));
|
||||
}
|
||||
|
||||
|
@ -413,10 +431,10 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
|||
break;
|
||||
}
|
||||
case Predicates::ConstraintQuestion: {
|
||||
auto value = cast<ConstraintQuestion>(question)->getValue();
|
||||
auto *cstQuestion = cast<ConstraintQuestion>(question);
|
||||
builder.create<pdl_interp::ApplyConstraintOp>(
|
||||
loc, std::get<0>(value), args, std::get<2>(value).cast<ArrayAttr>(),
|
||||
success, failure);
|
||||
loc, cstQuestion->getName(), args, cstQuestion->getParams(), success,
|
||||
failure);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -21,7 +21,7 @@ Position::~Position() {}
|
|||
unsigned Position::getOperationDepth() const {
|
||||
if (const auto *operationPos = dyn_cast<OperationPosition>(this))
|
||||
return operationPos->getDepth();
|
||||
return parent->getOperationDepth();
|
||||
return parent ? parent->getOperationDepth() : 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -50,6 +50,8 @@ enum Kind : unsigned {
|
|||
ResultPos,
|
||||
ResultGroupPos,
|
||||
TypePos,
|
||||
AttributeLiteralPos,
|
||||
TypeLiteralPos,
|
||||
|
||||
// Questions, ordered by dependency and decreasing priority.
|
||||
IsNotNullQuestion,
|
||||
|
@ -173,6 +175,16 @@ struct AttributePosition
|
|||
StringAttr getName() const { return key.second; }
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttributeLiteralPosition
|
||||
|
||||
/// A position describing a literal attribute.
|
||||
struct AttributeLiteralPosition
|
||||
: public PredicateBase<AttributeLiteralPosition, Position, Attribute,
|
||||
Predicates::AttributeLiteralPos> {
|
||||
using PredicateBase::PredicateBase;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperandPosition
|
||||
|
||||
|
@ -317,6 +329,17 @@ struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
|
|||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeLiteralPosition
|
||||
|
||||
/// A position describing a literal type or type range. The value is stored as
|
||||
/// either a TypeAttr, or an ArrayAttr of TypeAttr.
|
||||
struct TypeLiteralPosition
|
||||
: public PredicateBase<TypeLiteralPosition, Position, Attribute,
|
||||
Predicates::TypeLiteralPos> {
|
||||
using PredicateBase::PredicateBase;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Qualifiers
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -404,6 +427,17 @@ struct ConstraintQuestion
|
|||
Predicates::ConstraintQuestion> {
|
||||
using Base::Base;
|
||||
|
||||
/// Return the name of the constraint.
|
||||
StringRef getName() const { return std::get<0>(key); }
|
||||
|
||||
/// Return the arguments of the constraint.
|
||||
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
|
||||
|
||||
/// Return the constant parameters of the constraint.
|
||||
ArrayAttr getParams() const {
|
||||
return std::get<2>(key).dyn_cast_or_null<ArrayAttr>();
|
||||
}
|
||||
|
||||
/// Construct an instance with the given storage allocator.
|
||||
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
|
||||
KeyTy key) {
|
||||
|
@ -461,12 +495,14 @@ public:
|
|||
PredicateUniquer() {
|
||||
// Register the types of Positions with the uniquer.
|
||||
registerParametricStorageType<AttributePosition>();
|
||||
registerParametricStorageType<AttributeLiteralPosition>();
|
||||
registerParametricStorageType<OperandPosition>();
|
||||
registerParametricStorageType<OperandGroupPosition>();
|
||||
registerParametricStorageType<OperationPosition>();
|
||||
registerParametricStorageType<ResultPosition>();
|
||||
registerParametricStorageType<ResultGroupPosition>();
|
||||
registerParametricStorageType<TypePosition>();
|
||||
registerParametricStorageType<TypeLiteralPosition>();
|
||||
|
||||
// Register the types of Questions with the uniquer.
|
||||
registerParametricStorageType<AttributeAnswer>();
|
||||
|
@ -527,6 +563,11 @@ public:
|
|||
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
|
||||
}
|
||||
|
||||
/// Returns an attribute position for the given attribute.
|
||||
Position *getAttributeLiteral(Attribute attr) {
|
||||
return AttributeLiteralPosition::get(uniquer, attr);
|
||||
}
|
||||
|
||||
/// Returns an operand position for an operand of the given operation.
|
||||
Position *getOperand(OperationPosition *p, unsigned operand) {
|
||||
return OperandPosition::get(uniquer, p, operand);
|
||||
|
@ -558,6 +599,12 @@ public:
|
|||
/// Returns a type position for the given entity.
|
||||
Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
|
||||
|
||||
/// Returns a type position for the given type value. The value is stored
|
||||
/// as either a TypeAttr, or an ArrayAttr of TypeAttr.
|
||||
Position *getTypeLiteral(Attribute attr) {
|
||||
return TypeLiteralPosition::get(uniquer, attr);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Qualifiers
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
|
|
@ -243,8 +243,18 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
|||
.Default([](auto *) { llvm_unreachable("unexpected position kind"); });
|
||||
}
|
||||
|
||||
/// Collect all of the predicates related to constraints within the given
|
||||
/// pattern operation.
|
||||
static void getAttributePredicates(pdl::AttributeOp op,
|
||||
std::vector<PositionalPredicate> &predList,
|
||||
PredicateBuilder &builder,
|
||||
DenseMap<Value, Position *> &inputs) {
|
||||
Position *&attrPos = inputs[op];
|
||||
if (attrPos)
|
||||
return;
|
||||
Attribute value = op.valueAttr();
|
||||
assert(value && "expected non-tree `pdl.attribute` to contain a value");
|
||||
attrPos = builder.getAttributeLiteral(value);
|
||||
}
|
||||
|
||||
static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
|
||||
std::vector<PositionalPredicate> &predList,
|
||||
PredicateBuilder &builder,
|
||||
|
@ -296,6 +306,19 @@ static void getResultPredicates(pdl::ResultsOp op,
|
|||
predList.emplace_back(resultPos, builder.getIsNotNull());
|
||||
}
|
||||
|
||||
static void getTypePredicates(Value typeValue,
|
||||
function_ref<Attribute()> typeAttrFn,
|
||||
PredicateBuilder &builder,
|
||||
DenseMap<Value, Position *> &inputs) {
|
||||
Position *&typePos = inputs[typeValue];
|
||||
if (typePos)
|
||||
return;
|
||||
Attribute typeAttr = typeAttrFn();
|
||||
assert(typeAttr &&
|
||||
"expected non-tree `pdl.type`/`pdl.types` to contain a value");
|
||||
typePos = builder.getTypeLiteral(typeAttr);
|
||||
}
|
||||
|
||||
/// Collect all of the predicates that cannot be determined via walking the
|
||||
/// tree.
|
||||
static void getNonTreePredicates(pdl::PatternOp pattern,
|
||||
|
@ -304,11 +327,22 @@ static void getNonTreePredicates(pdl::PatternOp pattern,
|
|||
DenseMap<Value, Position *> &inputs) {
|
||||
for (Operation &op : pattern.body().getOps()) {
|
||||
TypeSwitch<Operation *>(&op)
|
||||
.Case([&](pdl::AttributeOp attrOp) {
|
||||
getAttributePredicates(attrOp, predList, builder, inputs);
|
||||
})
|
||||
.Case<pdl::ApplyNativeConstraintOp>([&](auto constraintOp) {
|
||||
getConstraintPredicates(constraintOp, predList, builder, inputs);
|
||||
})
|
||||
.Case<pdl::ResultOp, pdl::ResultsOp>([&](auto resultOp) {
|
||||
getResultPredicates(resultOp, predList, builder, inputs);
|
||||
})
|
||||
.Case([&](pdl::TypeOp typeOp) {
|
||||
getTypePredicates(
|
||||
typeOp, [&] { return typeOp.typeAttr(); }, builder, inputs);
|
||||
})
|
||||
.Case([&](pdl::TypesOp typeOp) {
|
||||
getTypePredicates(
|
||||
typeOp, [&] { return typeOp.typesAttr(); }, builder, inputs);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,12 +114,15 @@ static LogicalResult verify(AttributeOp op) {
|
|||
Value attrType = op.type();
|
||||
Optional<Attribute> attrValue = op.value();
|
||||
|
||||
if (!attrValue && isa<RewriteOp>(op->getParentOp()))
|
||||
return op.emitOpError("expected constant value when specified within a "
|
||||
"`pdl.rewrite`");
|
||||
if (attrValue && attrType)
|
||||
if (!attrValue) {
|
||||
if (isa<RewriteOp>(op->getParentOp()))
|
||||
return op.emitOpError("expected constant value when specified within a "
|
||||
"`pdl.rewrite`");
|
||||
return verifyHasBindingUse(op);
|
||||
}
|
||||
if (attrType)
|
||||
return op.emitOpError("expected only one of [`type`, `value`] to be set");
|
||||
return verifyHasBindingUse(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -431,13 +434,21 @@ static LogicalResult verify(RewriteOp op) {
|
|||
// pdl::TypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(TypeOp op) { return verifyHasBindingUse(op); }
|
||||
static LogicalResult verify(TypeOp op) {
|
||||
if (!op.typeAttr())
|
||||
return verifyHasBindingUse(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl::TypesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(TypesOp op) { return verifyHasBindingUse(op); }
|
||||
static LogicalResult verify(TypesOp op) {
|
||||
if (!op.typesAttr())
|
||||
return verifyHasBindingUse(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
|
|
|
@ -573,3 +573,42 @@ module @variadic_results_at {
|
|||
pdl.rewrite with "rewriter"(%root1, %root2 : !pdl.operation, !pdl.operation)
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module @attribute_literal
|
||||
module @attribute_literal {
|
||||
// CHECK: func @matcher(%{{.*}}: !pdl.operation)
|
||||
// CHECK: %[[ATTR:.*]] = pdl_interp.create_attribute 10 : i64
|
||||
// CHECK: pdl_interp.apply_constraint "constraint"(%[[ATTR]] : !pdl.attribute)
|
||||
|
||||
// Check the correct lowering of an attribute that hasn't been bound.
|
||||
pdl.pattern : benefit(1) {
|
||||
%attr = pdl.attribute 10
|
||||
pdl.apply_native_constraint "constraint"(%attr: !pdl.attribute)
|
||||
|
||||
%root = pdl.operation
|
||||
pdl.rewrite %root with "rewriter"
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module @type_literal
|
||||
module @type_literal {
|
||||
// CHECK: func @matcher(%{{.*}}: !pdl.operation)
|
||||
// CHECK: %[[TYPE:.*]] = pdl_interp.create_type i32
|
||||
// CHECK: %[[TYPES:.*]] = pdl_interp.create_types [i32, i64]
|
||||
// CHECK: pdl_interp.apply_constraint "constraint"(%[[TYPE]], %[[TYPES]] : !pdl.type, !pdl.range<type>)
|
||||
|
||||
// Check the correct lowering of a type that hasn't been bound.
|
||||
pdl.pattern : benefit(1) {
|
||||
%type = pdl.type : i32
|
||||
%types = pdl.types : [i32, i64]
|
||||
pdl.apply_native_constraint "constraint"(%type, %types: !pdl.type, !pdl.range<type>)
|
||||
|
||||
%root = pdl.operation
|
||||
pdl.rewrite %root with "rewriter"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue