[mlir] Make ConversionTarget dynamic legality callbacks composable

* Change callback signature `bool(Operation *)` -> `Optional<bool>(Operation *)`
* addDynamicallyLegalOp add callback to the chain
* If callback returned empty `Optional` next callback in chain will be called

Differential Revision: https://reviews.llvm.org/D110487
This commit is contained in:
Caitlyn Cano 2021-07-01 20:41:51 +00:00 committed by Butygin
parent 649cc160e3
commit c6828e0cea
5 changed files with 134 additions and 13 deletions

View File

@ -661,7 +661,7 @@ public:
/// The signature of the callback used to determine if an operation is
/// dynamically legal on the target.
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
@ -827,10 +827,10 @@ private:
/// The set of information that configures the legalization of an operation.
struct LegalizationInfo {
/// The legality action this operation was given.
LegalizationAction action;
LegalizationAction action = LegalizationAction::Illegal;
/// If some legal instances of this operation may also be recursively legal.
bool isRecursivelyLegal;
bool isRecursivelyLegal = false;
/// The legality callback if this operation is dynamically legal.
DynamicLegalityCallbackFn legalityFn;

View File

@ -2681,7 +2681,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
/// Register a legality action for the given operation.
void ConversionTarget::setOpAction(OperationName op,
LegalizationAction action) {
legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
legalOperations[op].action = action;
}
/// Register a legality action for the given dialects.
@ -2710,8 +2710,11 @@ auto ConversionTarget::isLegal(Operation *op) const
// Returns true if this operation instance is known to be legal.
auto isOpLegal = [&] {
// Handle dynamic legality either with the provided legality function.
if (info->action == LegalizationAction::Dynamic)
return info->legalityFn(op);
if (info->action == LegalizationAction::Dynamic) {
Optional<bool> result = info->legalityFn(op);
if (result)
return *result;
}
// Otherwise, the operation is only legal if it was marked 'Legal'.
return info->action == LegalizationAction::Legal;
@ -2723,14 +2726,32 @@ auto ConversionTarget::isLegal(Operation *op) const
LegalOpDetails legalityDetails;
if (info->isRecursivelyLegal) {
auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
if (legalityFnIt != opRecursiveLegalityFns.end())
legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
else
if (legalityFnIt != opRecursiveLegalityFns.end()) {
legalityDetails.isRecursivelyLegal =
legalityFnIt->second(op).getValueOr(true);
} else {
legalityDetails.isRecursivelyLegal = true;
}
}
return legalityDetails;
}
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
ConversionTarget::DynamicLegalityCallbackFn oldCallback,
ConversionTarget::DynamicLegalityCallbackFn newCallback) {
if (!oldCallback)
return newCallback;
auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
Operation *op) -> Optional<bool> {
if (Optional<bool> result = newCl(op))
return *result;
return oldCl(op);
};
return chain;
}
/// Set the dynamic legality callback for the given operation.
void ConversionTarget::setLegalityCallback(
OperationName name, const DynamicLegalityCallbackFn &callback) {
@ -2739,7 +2760,8 @@ void ConversionTarget::setLegalityCallback(
assert(infoIt != legalOperations.end() &&
infoIt->second.action == LegalizationAction::Dynamic &&
"expected operation to already be marked as dynamically legal");
infoIt->second.legalityFn = callback;
infoIt->second.legalityFn =
composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
}
/// Set the recursive legality callback for the given operation and mark the
@ -2752,7 +2774,8 @@ void ConversionTarget::markOpRecursivelyLegal(
"expected operation to already be marked as legal");
infoIt->second.isRecursivelyLegal = true;
if (callback)
opRecursiveLegalityFns[name] = callback;
opRecursiveLegalityFns[name] = composeLegalityCallbacks(
std::move(opRecursiveLegalityFns[name]), callback);
else
opRecursiveLegalityFns.erase(name);
}
@ -2762,14 +2785,15 @@ void ConversionTarget::setLegalityCallback(
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
for (StringRef dialect : dialects)
dialectLegalityFns[dialect] = callback;
dialectLegalityFns[dialect] = composeLegalityCallbacks(
std::move(dialectLegalityFns[dialect]), callback);
}
/// Set the dynamic legality callback for the unknown ops.
void ConversionTarget::setLegalityCallback(
const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
unknownLegalityFn = callback;
unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
}
/// Get the legalization information for the given operation.

View File

@ -12,3 +12,4 @@ add_subdirectory(IR)
add_subdirectory(Pass)
add_subdirectory(Rewrite)
add_subdirectory(TableGen)
add_subdirectory(Transforms)

View File

@ -0,0 +1,6 @@
add_mlir_unittest(MLIRTransformsTests
DialectConversion.cpp
)
target_link_libraries(MLIRTransformsTests
PRIVATE
MLIRTransforms)

View File

@ -0,0 +1,90 @@
//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/DialectConversion.h"
#include "gtest/gtest.h"
using namespace mlir;
static Operation *createOp(MLIRContext *context) {
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), llvm::None,
llvm::None, llvm::None, llvm::None, 0);
}
namespace {
struct DummyOp {
static StringRef getOperationName() { return "foo.bar"; }
};
TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
MLIRContext context;
ConversionTarget target(context);
int index = 0;
int callbackCalled1 = 0;
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
callbackCalled1 = ++index;
return true;
});
int callbackCalled2 = 0;
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
callbackCalled2 = ++index;
return llvm::None;
});
auto *op = createOp(&context);
EXPECT_TRUE(target.isLegal(op));
EXPECT_EQ(2, callbackCalled1);
EXPECT_EQ(1, callbackCalled2);
op->destroy();
}
TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
MLIRContext context;
ConversionTarget target(context);
int index = 0;
int callbackCalled = 0;
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
callbackCalled = ++index;
return llvm::None;
});
auto *op = createOp(&context);
EXPECT_FALSE(target.isLegal(op));
EXPECT_EQ(1, callbackCalled);
op->destroy();
}
TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
MLIRContext context;
ConversionTarget target(context);
int index = 0;
int callbackCalled1 = 0;
target.markUnknownOpDynamicallyLegal([&](Operation *) {
callbackCalled1 = ++index;
return true;
});
int callbackCalled2 = 0;
target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
callbackCalled2 = ++index;
return llvm::None;
});
auto *op = createOp(&context);
EXPECT_TRUE(target.isLegal(op));
EXPECT_EQ(2, callbackCalled1);
EXPECT_EQ(1, callbackCalled2);
op->destroy();
}
} // namespace