diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index e6b652f21913..d0643947ed20 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -100,6 +100,11 @@ struct MyTarget : public ConversionTarget { /// callback. addDynamicallyLegalOp([](ReturnOp op) { ... }); + /// Treat unknown operations, i.e. those without a legalization action + /// directly set, as dynamically legal. + markUnknownOpDynamicallyLegal(); + markUnknownOpDynamicallyLegal([](Operation *op) { ... }); + //-------------------------------------------------------------------------- // Marking an operation as illegal. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index aadb59226c64..cd148f2fd2ea 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -416,7 +416,8 @@ public: /// dynamically legal on the target. using DynamicLegalityCallbackFn = std::function; - ConversionTarget(MLIRContext &ctx) : ctx(ctx) {} + ConversionTarget(MLIRContext &ctx) + : unknownOpsDynamicallyLegal(false), ctx(ctx) {} virtual ~ConversionTarget() = default; //===--------------------------------------------------------------------===// @@ -532,6 +533,16 @@ public: setLegalityCallback(dialectNames, *callback); } + /// Register unknown operations as dynamically legal. For operations(and + /// dialects) that do not have a set legalization action, treat them as + /// dynamically legal and invoke the given callback if valid or + /// 'isDynamicallyLegal'. + void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) { + unknownOpsDynamicallyLegal = true; + unknownLegalityFn = fn; + } + void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; } + /// Register the operations of the given dialects as illegal, i.e. /// operations of this dialect are not supported by the target. template @@ -585,6 +596,9 @@ private: /// If some legal instances of this operation may also be recursively legal. bool isRecursivelyLegal; + + /// The legality callback if this operation is dynamically legal. + Optional legalityFn; }; /// Get the legalization information for the given operation. @@ -594,9 +608,6 @@ private: /// information. llvm::MapVector legalOperations; - /// A set of dynamic legality callbacks for given operation names. - DenseMap opLegalityFns; - /// A set of legality callbacks for given operation names that are used to /// check if an operation instance is recursively legal. DenseMap opRecursiveLegalityFns; @@ -608,6 +619,13 @@ private: /// A set of dynamic legality callbacks for given dialect names. llvm::StringMap dialectLegalityFns; + /// An optional legality callback for unknown operations. + Optional unknownLegalityFn; + + /// Flag indicating if unknown operations should be treated as dynamically + /// legal. + bool unknownOpsDynamicallyLegal; + /// The current context this target applies to. MLIRContext &ctx; }; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index e2cd12e0a7a4..c6e7f9b88b46 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1704,19 +1704,11 @@ auto ConversionTarget::isLegal(Operation *op) const // Returns true if this operation instance is known to be legal. auto isOpLegal = [&] { - // Handle dynamic legality. - if (info->action == LegalizationAction::Dynamic) { - // Check for callbacks on the operation or dialect. - auto opFn = opLegalityFns.find(op->getName()); - if (opFn != opLegalityFns.end()) - return opFn->second(op); - auto dialectFn = dialectLegalityFns.find(op->getName().getDialect()); - if (dialectFn != dialectLegalityFns.end()) - return dialectFn->second(op); - - // Otherwise, invoke the hook on the derived instance. - return isDynamicallyLegal(op); - } + // Handle dynamic legality either with the provided legality function, or + // the default hook on the derived instance. + if (info->action == LegalizationAction::Dynamic) + return info->legalityFn ? (*info->legalityFn)(op) + : isDynamicallyLegal(op); // Otherwise, the operation is only legal if it was marked 'Legal'. return info->action == LegalizationAction::Legal; @@ -1726,7 +1718,6 @@ auto ConversionTarget::isLegal(Operation *op) const // This operation is legal, compute any additional legality information. LegalOpDetails legalityDetails; - if (info->isRecursivelyLegal) { auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); if (legalityFnIt != opRecursiveLegalityFns.end()) @@ -1741,7 +1732,11 @@ auto ConversionTarget::isLegal(Operation *op) const void ConversionTarget::setLegalityCallback( OperationName name, const DynamicLegalityCallbackFn &callback) { assert(callback && "expected valid legality callback"); - opLegalityFns[name] = callback; + auto infoIt = legalOperations.find(name); + assert(infoIt != legalOperations.end() && + infoIt->second.action == LegalizationAction::Dynamic && + "expected operation to already be marked as dynamically legal"); + infoIt->second.legalityFn = callback; } /// Set the recursive legality callback for the given operation and mark the @@ -1774,10 +1769,20 @@ auto ConversionTarget::getOpInfo(OperationName op) const auto it = legalOperations.find(op); if (it != legalOperations.end()) return it->second; - // Otherwise, default to checking on the parent dialect. + // Check for info for the parent dialect. auto dialectIt = legalDialects.find(op.getDialect()); - if (dialectIt != legalDialects.end()) - return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false}; + if (dialectIt != legalDialects.end()) { + Optional callback; + auto dialectFn = dialectLegalityFns.find(op.getDialect()); + if (dialectFn != dialectLegalityFns.end()) + callback = dialectFn->second; + return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, + callback}; + } + // Otherwise, check if we mark unknown operations as dynamic. + if (unknownOpsDynamicallyLegal) + return LegalizationInfo{LegalizationAction::Dynamic, + /*isRecursivelyLegal=*/false, unknownLegalityFn}; return llvm::None; } diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index d0fc4c9bc898..6bbda4aad2ba 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -58,3 +58,14 @@ func @test_undo_region_clone() { %ignored = "test.illegal_op_f"() : () -> (i32) "test.return"() : () -> () } + +// ----- + +// Test that unknown operations can be dynamically legal. +func @test_unknown_dynamically_legal() { + "foo.unknown_op"() {test.dynamically_legal} : () -> () + + // expected-error@+1 {{failed to legalize operation 'foo.unknown_op'}} + "foo.unknown_op"() {} : () -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index d9777487986b..d34181c45b18 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -399,6 +399,11 @@ struct TestLegalizePatternDriver // Handle a full conversion. if (mode == ConversionMode::Full) { + // Check support for marking unknown operations as dynamically legal. + target.markUnknownOpDynamicallyLegal([](Operation *op) { + return (bool)op->getAttrOfType("test.dynamically_legal"); + }); + (void)applyFullConversion(getModule(), target, patterns, &converter); return; }