[mlir] ConversionTarget legality callbacks refactoring

* Get rid of Optional<std::function> as std::function already have a null state
* Add private setLegalityCallback function to set legality callback for unknown ops
* Get rid of unknownOpsDynamicallyLegal flag, use unknownLegalityFn state insted. This causes behavior change when user first calls markUnknownOpDynamicallyLegal with callback and then without but I am not sure is the original behavior was really a 'feature', or just oversignt in the original implementation.

Differential Revision: https://reviews.llvm.org/D105496
This commit is contained in:
Butygin 2021-07-06 19:11:16 +03:00
parent 05ae303555
commit b7a4649899
2 changed files with 23 additions and 25 deletions

View File

@ -621,8 +621,7 @@ public:
/// dynamically legal on the target.
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
ConversionTarget(MLIRContext &ctx)
: unknownOpsDynamicallyLegal(false), ctx(ctx) {}
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
virtual ~ConversionTarget() = default;
//===--------------------------------------------------------------------===//
@ -739,18 +738,11 @@ public:
setDialectAction(dialectNames, LegalizationAction::Dynamic);
}
template <typename... Args>
void addDynamicallyLegalDialect(
Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback = {}) {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Dynamic);
if (callback)
setLegalityCallback(dialectNames, *callback);
}
template <typename... Args>
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Dynamic);
setLegalityCallback(dialectNames, callback);
setLegalityCallback(dialectNames, callback);
}
/// Register unknown operations as dynamically legal. For operations(and
@ -758,10 +750,11 @@ public:
/// dynamically legal and invoke the given callback if valid or
/// 'isDynamicallyLegal'.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
unknownOpsDynamicallyLegal = true;
unknownLegalityFn = fn;
setLegalityCallback(fn);
}
void markUnknownOpDynamicallyLegal() {
setLegalityCallback([](Operation *) { return true; });
}
void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; }
/// Register the operations of the given dialects as illegal, i.e.
/// operations of this dialect are not supported by the target.
@ -805,6 +798,9 @@ private:
void setLegalityCallback(ArrayRef<StringRef> dialects,
const DynamicLegalityCallbackFn &callback);
/// Set the dynamic legality callback for the unknown ops.
void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
/// Set the recursive legality callback for the given operation and mark the
/// operation as recursively legal.
void markOpRecursivelyLegal(OperationName name,
@ -819,7 +815,7 @@ private:
bool isRecursivelyLegal;
/// The legality callback if this operation is dynamically legal.
Optional<DynamicLegalityCallbackFn> legalityFn;
DynamicLegalityCallbackFn legalityFn;
};
/// Get the legalization information for the given operation.
@ -841,11 +837,7 @@ private:
llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
/// An optional legality callback for unknown operations.
Optional<DynamicLegalityCallbackFn> unknownLegalityFn;
/// Flag indicating if unknown operations should be treated as dynamically
/// legal.
bool unknownOpsDynamicallyLegal;
DynamicLegalityCallbackFn unknownLegalityFn;
/// The current context this target applies to.
MLIRContext &ctx;

View File

@ -2672,7 +2672,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
/// Register a legality action for the given operation.
void ConversionTarget::setOpAction(OperationName op,
LegalizationAction action) {
legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
}
/// Register a legality action for the given dialects.
@ -2703,8 +2703,7 @@ auto ConversionTarget::isLegal(Operation *op) const
// Handle dynamic legality either with the provided legality function, or
// the default hook on the derived instance.
if (info->action == LegalizationAction::Dynamic)
return info->legalityFn ? (*info->legalityFn)(op)
: isDynamicallyLegal(op);
return info->legalityFn ? info->legalityFn(op) : isDynamicallyLegal(op);
// Otherwise, the operation is only legal if it was marked 'Legal'.
return info->action == LegalizationAction::Legal;
@ -2758,6 +2757,13 @@ void ConversionTarget::setLegalityCallback(
dialectLegalityFns[dialect] = callback;
}
/// Set the dynamic legality callback for the unknown ops.
void ConversionTarget::setLegalityCallback(
const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
unknownLegalityFn = callback;
}
/// Get the legalization information for the given operation.
auto ConversionTarget::getOpInfo(OperationName op) const
-> Optional<LegalizationInfo> {
@ -2768,7 +2774,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
// Check for info for the parent dialect.
auto dialectIt = legalDialects.find(op.getDialectNamespace());
if (dialectIt != legalDialects.end()) {
Optional<DynamicLegalityCallbackFn> callback;
DynamicLegalityCallbackFn callback;
auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
if (dialectFn != dialectLegalityFns.end())
callback = dialectFn->second;
@ -2776,7 +2782,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
callback};
}
// Otherwise, check if we mark unknown operations as dynamic.
if (unknownOpsDynamicallyLegal)
if (unknownLegalityFn)
return LegalizationInfo{LegalizationAction::Dynamic,
/*isRecursivelyLegal=*/false, unknownLegalityFn};
return llvm::None;