[mlir] Add support for marking 'unknown' operations as dynamically legal.

Summary: This allows for providing a default "catchall" legality check that is not dependent on specific operations or dialects. For example, this can be useful to check legality based on the specific types of operation operands or results.

Differential Revision: https://reviews.llvm.org/D73379
This commit is contained in:
River Riddle 2020-01-27 19:04:55 -08:00
parent 49532137d0
commit ce674b131b
5 changed files with 66 additions and 22 deletions

View File

@ -100,6 +100,11 @@ struct MyTarget : public ConversionTarget {
/// callback.
addDynamicallyLegalOp<ReturnOp>([](ReturnOp op) { ... });
/// Treat unknown operations, i.e. those without a legalization action
/// directly set, as dynamically legal.
markUnknownOpDynamicallyLegal();
markUnknownOpDynamicallyLegal([](Operation *op) { ... });
//--------------------------------------------------------------------------
// Marking an operation as illegal.

View File

@ -416,7 +416,8 @@ public:
/// dynamically legal on the target.
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
ConversionTarget(MLIRContext &ctx)
: unknownOpsDynamicallyLegal(false), ctx(ctx) {}
virtual ~ConversionTarget() = default;
//===--------------------------------------------------------------------===//
@ -532,6 +533,16 @@ public:
setLegalityCallback(dialectNames, *callback);
}
/// Register unknown operations as dynamically legal. For operations(and
/// dialects) that do not have a set legalization action, treat them as
/// dynamically legal and invoke the given callback if valid or
/// 'isDynamicallyLegal'.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
unknownOpsDynamicallyLegal = true;
unknownLegalityFn = fn;
}
void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; }
/// Register the operations of the given dialects as illegal, i.e.
/// operations of this dialect are not supported by the target.
template <typename... Names>
@ -585,6 +596,9 @@ private:
/// If some legal instances of this operation may also be recursively legal.
bool isRecursivelyLegal;
/// The legality callback if this operation is dynamically legal.
Optional<DynamicLegalityCallbackFn> legalityFn;
};
/// Get the legalization information for the given operation.
@ -594,9 +608,6 @@ private:
/// information.
llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
/// A set of dynamic legality callbacks for given operation names.
DenseMap<OperationName, DynamicLegalityCallbackFn> opLegalityFns;
/// A set of legality callbacks for given operation names that are used to
/// check if an operation instance is recursively legal.
DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
@ -608,6 +619,13 @@ private:
/// A set of dynamic legality callbacks for given dialect names.
llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
/// An optional legality callback for unknown operations.
Optional<DynamicLegalityCallbackFn> unknownLegalityFn;
/// Flag indicating if unknown operations should be treated as dynamically
/// legal.
bool unknownOpsDynamicallyLegal;
/// The current context this target applies to.
MLIRContext &ctx;
};

View File

@ -1704,19 +1704,11 @@ auto ConversionTarget::isLegal(Operation *op) const
// Returns true if this operation instance is known to be legal.
auto isOpLegal = [&] {
// Handle dynamic legality.
if (info->action == LegalizationAction::Dynamic) {
// Check for callbacks on the operation or dialect.
auto opFn = opLegalityFns.find(op->getName());
if (opFn != opLegalityFns.end())
return opFn->second(op);
auto dialectFn = dialectLegalityFns.find(op->getName().getDialect());
if (dialectFn != dialectLegalityFns.end())
return dialectFn->second(op);
// Otherwise, invoke the hook on the derived instance.
return isDynamicallyLegal(op);
}
// Handle dynamic legality either with the provided legality function, or
// the default hook on the derived instance.
if (info->action == LegalizationAction::Dynamic)
return info->legalityFn ? (*info->legalityFn)(op)
: isDynamicallyLegal(op);
// Otherwise, the operation is only legal if it was marked 'Legal'.
return info->action == LegalizationAction::Legal;
@ -1726,7 +1718,6 @@ auto ConversionTarget::isLegal(Operation *op) const
// This operation is legal, compute any additional legality information.
LegalOpDetails legalityDetails;
if (info->isRecursivelyLegal) {
auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
if (legalityFnIt != opRecursiveLegalityFns.end())
@ -1741,7 +1732,11 @@ auto ConversionTarget::isLegal(Operation *op) const
void ConversionTarget::setLegalityCallback(
OperationName name, const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
opLegalityFns[name] = callback;
auto infoIt = legalOperations.find(name);
assert(infoIt != legalOperations.end() &&
infoIt->second.action == LegalizationAction::Dynamic &&
"expected operation to already be marked as dynamically legal");
infoIt->second.legalityFn = callback;
}
/// Set the recursive legality callback for the given operation and mark the
@ -1774,10 +1769,20 @@ auto ConversionTarget::getOpInfo(OperationName op) const
auto it = legalOperations.find(op);
if (it != legalOperations.end())
return it->second;
// Otherwise, default to checking on the parent dialect.
// Check for info for the parent dialect.
auto dialectIt = legalDialects.find(op.getDialect());
if (dialectIt != legalDialects.end())
return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false};
if (dialectIt != legalDialects.end()) {
Optional<DynamicLegalityCallbackFn> callback;
auto dialectFn = dialectLegalityFns.find(op.getDialect());
if (dialectFn != dialectLegalityFns.end())
callback = dialectFn->second;
return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,
callback};
}
// Otherwise, check if we mark unknown operations as dynamic.
if (unknownOpsDynamicallyLegal)
return LegalizationInfo{LegalizationAction::Dynamic,
/*isRecursivelyLegal=*/false, unknownLegalityFn};
return llvm::None;
}

View File

@ -58,3 +58,14 @@ func @test_undo_region_clone() {
%ignored = "test.illegal_op_f"() : () -> (i32)
"test.return"() : () -> ()
}
// -----
// Test that unknown operations can be dynamically legal.
func @test_unknown_dynamically_legal() {
"foo.unknown_op"() {test.dynamically_legal} : () -> ()
// expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}}
"foo.unknown_op"() {} : () -> ()
"test.return"() : () -> ()
}

View File

@ -399,6 +399,11 @@ struct TestLegalizePatternDriver
// Handle a full conversion.
if (mode == ConversionMode::Full) {
// Check support for marking unknown operations as dynamically legal.
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
});
(void)applyFullConversion(getModule(), target, patterns, &converter);
return;
}