forked from OSchip/llvm-project
[mlir] Make ConversionTarget dynamic legality callbacks composable
* Change callback signature `bool(Operation *)` -> `Optional<bool>(Operation *)` * addDynamicallyLegalOp add callback to the chain * If callback returned empty `Optional` next callback in chain will be called Differential Revision: https://reviews.llvm.org/D110487
This commit is contained in:
parent
649cc160e3
commit
c6828e0cea
|
@ -661,7 +661,7 @@ public:
|
|||
|
||||
/// The signature of the callback used to determine if an operation is
|
||||
/// dynamically legal on the target.
|
||||
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
|
||||
using DynamicLegalityCallbackFn = std::function<Optional<bool>(Operation *)>;
|
||||
|
||||
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
|
||||
virtual ~ConversionTarget() = default;
|
||||
|
@ -827,10 +827,10 @@ private:
|
|||
/// The set of information that configures the legalization of an operation.
|
||||
struct LegalizationInfo {
|
||||
/// The legality action this operation was given.
|
||||
LegalizationAction action;
|
||||
LegalizationAction action = LegalizationAction::Illegal;
|
||||
|
||||
/// If some legal instances of this operation may also be recursively legal.
|
||||
bool isRecursivelyLegal;
|
||||
bool isRecursivelyLegal = false;
|
||||
|
||||
/// The legality callback if this operation is dynamically legal.
|
||||
DynamicLegalityCallbackFn legalityFn;
|
||||
|
|
|
@ -2681,7 +2681,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
|
|||
/// Register a legality action for the given operation.
|
||||
void ConversionTarget::setOpAction(OperationName op,
|
||||
LegalizationAction action) {
|
||||
legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
|
||||
legalOperations[op].action = action;
|
||||
}
|
||||
|
||||
/// Register a legality action for the given dialects.
|
||||
|
@ -2710,8 +2710,11 @@ auto ConversionTarget::isLegal(Operation *op) const
|
|||
// Returns true if this operation instance is known to be legal.
|
||||
auto isOpLegal = [&] {
|
||||
// Handle dynamic legality either with the provided legality function.
|
||||
if (info->action == LegalizationAction::Dynamic)
|
||||
return info->legalityFn(op);
|
||||
if (info->action == LegalizationAction::Dynamic) {
|
||||
Optional<bool> result = info->legalityFn(op);
|
||||
if (result)
|
||||
return *result;
|
||||
}
|
||||
|
||||
// Otherwise, the operation is only legal if it was marked 'Legal'.
|
||||
return info->action == LegalizationAction::Legal;
|
||||
|
@ -2723,14 +2726,32 @@ auto ConversionTarget::isLegal(Operation *op) const
|
|||
LegalOpDetails legalityDetails;
|
||||
if (info->isRecursivelyLegal) {
|
||||
auto legalityFnIt = opRecursiveLegalityFns.find(op->getName());
|
||||
if (legalityFnIt != opRecursiveLegalityFns.end())
|
||||
legalityDetails.isRecursivelyLegal = legalityFnIt->second(op);
|
||||
else
|
||||
if (legalityFnIt != opRecursiveLegalityFns.end()) {
|
||||
legalityDetails.isRecursivelyLegal =
|
||||
legalityFnIt->second(op).getValueOr(true);
|
||||
} else {
|
||||
legalityDetails.isRecursivelyLegal = true;
|
||||
}
|
||||
}
|
||||
return legalityDetails;
|
||||
}
|
||||
|
||||
static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
|
||||
ConversionTarget::DynamicLegalityCallbackFn oldCallback,
|
||||
ConversionTarget::DynamicLegalityCallbackFn newCallback) {
|
||||
if (!oldCallback)
|
||||
return newCallback;
|
||||
|
||||
auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)](
|
||||
Operation *op) -> Optional<bool> {
|
||||
if (Optional<bool> result = newCl(op))
|
||||
return *result;
|
||||
|
||||
return oldCl(op);
|
||||
};
|
||||
return chain;
|
||||
}
|
||||
|
||||
/// Set the dynamic legality callback for the given operation.
|
||||
void ConversionTarget::setLegalityCallback(
|
||||
OperationName name, const DynamicLegalityCallbackFn &callback) {
|
||||
|
@ -2739,7 +2760,8 @@ void ConversionTarget::setLegalityCallback(
|
|||
assert(infoIt != legalOperations.end() &&
|
||||
infoIt->second.action == LegalizationAction::Dynamic &&
|
||||
"expected operation to already be marked as dynamically legal");
|
||||
infoIt->second.legalityFn = callback;
|
||||
infoIt->second.legalityFn =
|
||||
composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
|
||||
}
|
||||
|
||||
/// Set the recursive legality callback for the given operation and mark the
|
||||
|
@ -2752,7 +2774,8 @@ void ConversionTarget::markOpRecursivelyLegal(
|
|||
"expected operation to already be marked as legal");
|
||||
infoIt->second.isRecursivelyLegal = true;
|
||||
if (callback)
|
||||
opRecursiveLegalityFns[name] = callback;
|
||||
opRecursiveLegalityFns[name] = composeLegalityCallbacks(
|
||||
std::move(opRecursiveLegalityFns[name]), callback);
|
||||
else
|
||||
opRecursiveLegalityFns.erase(name);
|
||||
}
|
||||
|
@ -2762,14 +2785,15 @@ void ConversionTarget::setLegalityCallback(
|
|||
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
|
||||
assert(callback && "expected valid legality callback");
|
||||
for (StringRef dialect : dialects)
|
||||
dialectLegalityFns[dialect] = callback;
|
||||
dialectLegalityFns[dialect] = composeLegalityCallbacks(
|
||||
std::move(dialectLegalityFns[dialect]), callback);
|
||||
}
|
||||
|
||||
/// Set the dynamic legality callback for the unknown ops.
|
||||
void ConversionTarget::setLegalityCallback(
|
||||
const DynamicLegalityCallbackFn &callback) {
|
||||
assert(callback && "expected valid legality callback");
|
||||
unknownLegalityFn = callback;
|
||||
unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
|
||||
}
|
||||
|
||||
/// Get the legalization information for the given operation.
|
||||
|
|
|
@ -12,3 +12,4 @@ add_subdirectory(IR)
|
|||
add_subdirectory(Pass)
|
||||
add_subdirectory(Rewrite)
|
||||
add_subdirectory(TableGen)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
add_mlir_unittest(MLIRTransformsTests
|
||||
DialectConversion.cpp
|
||||
)
|
||||
target_link_libraries(MLIRTransformsTests
|
||||
PRIVATE
|
||||
MLIRTransforms)
|
|
@ -0,0 +1,90 @@
|
|||
//===- DialectConversion.cpp - Dialect conversion unit tests --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static Operation *createOp(MLIRContext *context) {
|
||||
context->allowUnregisteredDialects();
|
||||
return Operation::create(UnknownLoc::get(context),
|
||||
OperationName("foo.bar", context), llvm::None,
|
||||
llvm::None, llvm::None, llvm::None, 0);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct DummyOp {
|
||||
static StringRef getOperationName() { return "foo.bar"; }
|
||||
};
|
||||
|
||||
TEST(DialectConversionTest, DynamicallyLegalOpCallbackOrder) {
|
||||
MLIRContext context;
|
||||
ConversionTarget target(context);
|
||||
|
||||
int index = 0;
|
||||
int callbackCalled1 = 0;
|
||||
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) {
|
||||
callbackCalled1 = ++index;
|
||||
return true;
|
||||
});
|
||||
|
||||
int callbackCalled2 = 0;
|
||||
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
|
||||
callbackCalled2 = ++index;
|
||||
return llvm::None;
|
||||
});
|
||||
|
||||
auto *op = createOp(&context);
|
||||
EXPECT_TRUE(target.isLegal(op));
|
||||
EXPECT_EQ(2, callbackCalled1);
|
||||
EXPECT_EQ(1, callbackCalled2);
|
||||
op->destroy();
|
||||
}
|
||||
|
||||
TEST(DialectConversionTest, DynamicallyLegalOpCallbackSkip) {
|
||||
MLIRContext context;
|
||||
ConversionTarget target(context);
|
||||
|
||||
int index = 0;
|
||||
int callbackCalled = 0;
|
||||
target.addDynamicallyLegalOp<DummyOp>([&](Operation *) -> Optional<bool> {
|
||||
callbackCalled = ++index;
|
||||
return llvm::None;
|
||||
});
|
||||
|
||||
auto *op = createOp(&context);
|
||||
EXPECT_FALSE(target.isLegal(op));
|
||||
EXPECT_EQ(1, callbackCalled);
|
||||
op->destroy();
|
||||
}
|
||||
|
||||
TEST(DialectConversionTest, DynamicallyLegalUnknownOpCallbackOrder) {
|
||||
MLIRContext context;
|
||||
ConversionTarget target(context);
|
||||
|
||||
int index = 0;
|
||||
int callbackCalled1 = 0;
|
||||
target.markUnknownOpDynamicallyLegal([&](Operation *) {
|
||||
callbackCalled1 = ++index;
|
||||
return true;
|
||||
});
|
||||
|
||||
int callbackCalled2 = 0;
|
||||
target.markUnknownOpDynamicallyLegal([&](Operation *) -> Optional<bool> {
|
||||
callbackCalled2 = ++index;
|
||||
return llvm::None;
|
||||
});
|
||||
|
||||
auto *op = createOp(&context);
|
||||
EXPECT_TRUE(target.isLegal(op));
|
||||
EXPECT_EQ(2, callbackCalled1);
|
||||
EXPECT_EQ(1, callbackCalled2);
|
||||
op->destroy();
|
||||
}
|
||||
} // namespace
|
Loading…
Reference in New Issue