forked from OSchip/llvm-project
[mlir] ConversionTarget legality callbacks refactoring
* Get rid of Optional<std::function> as std::function already have a null state * Add private setLegalityCallback function to set legality callback for unknown ops * Get rid of unknownOpsDynamicallyLegal flag, use unknownLegalityFn state insted. This causes behavior change when user first calls markUnknownOpDynamicallyLegal with callback and then without but I am not sure is the original behavior was really a 'feature', or just oversignt in the original implementation. Differential Revision: https://reviews.llvm.org/D105496
This commit is contained in:
parent
05ae303555
commit
b7a4649899
|
@ -621,8 +621,7 @@ public:
|
|||
/// dynamically legal on the target.
|
||||
using DynamicLegalityCallbackFn = std::function<bool(Operation *)>;
|
||||
|
||||
ConversionTarget(MLIRContext &ctx)
|
||||
: unknownOpsDynamicallyLegal(false), ctx(ctx) {}
|
||||
ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
|
||||
virtual ~ConversionTarget() = default;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -739,18 +738,11 @@ public:
|
|||
setDialectAction(dialectNames, LegalizationAction::Dynamic);
|
||||
}
|
||||
template <typename... Args>
|
||||
void addDynamicallyLegalDialect(
|
||||
Optional<DynamicLegalityCallbackFn> callback = llvm::None) {
|
||||
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback = {}) {
|
||||
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
|
||||
setDialectAction(dialectNames, LegalizationAction::Dynamic);
|
||||
if (callback)
|
||||
setLegalityCallback(dialectNames, *callback);
|
||||
}
|
||||
template <typename... Args>
|
||||
void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
|
||||
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
|
||||
setDialectAction(dialectNames, LegalizationAction::Dynamic);
|
||||
setLegalityCallback(dialectNames, callback);
|
||||
setLegalityCallback(dialectNames, callback);
|
||||
}
|
||||
|
||||
/// Register unknown operations as dynamically legal. For operations(and
|
||||
|
@ -758,10 +750,11 @@ public:
|
|||
/// dynamically legal and invoke the given callback if valid or
|
||||
/// 'isDynamicallyLegal'.
|
||||
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
|
||||
unknownOpsDynamicallyLegal = true;
|
||||
unknownLegalityFn = fn;
|
||||
setLegalityCallback(fn);
|
||||
}
|
||||
void markUnknownOpDynamicallyLegal() {
|
||||
setLegalityCallback([](Operation *) { return true; });
|
||||
}
|
||||
void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; }
|
||||
|
||||
/// Register the operations of the given dialects as illegal, i.e.
|
||||
/// operations of this dialect are not supported by the target.
|
||||
|
@ -805,6 +798,9 @@ private:
|
|||
void setLegalityCallback(ArrayRef<StringRef> dialects,
|
||||
const DynamicLegalityCallbackFn &callback);
|
||||
|
||||
/// Set the dynamic legality callback for the unknown ops.
|
||||
void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
|
||||
|
||||
/// Set the recursive legality callback for the given operation and mark the
|
||||
/// operation as recursively legal.
|
||||
void markOpRecursivelyLegal(OperationName name,
|
||||
|
@ -819,7 +815,7 @@ private:
|
|||
bool isRecursivelyLegal;
|
||||
|
||||
/// The legality callback if this operation is dynamically legal.
|
||||
Optional<DynamicLegalityCallbackFn> legalityFn;
|
||||
DynamicLegalityCallbackFn legalityFn;
|
||||
};
|
||||
|
||||
/// Get the legalization information for the given operation.
|
||||
|
@ -841,11 +837,7 @@ private:
|
|||
llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
|
||||
|
||||
/// An optional legality callback for unknown operations.
|
||||
Optional<DynamicLegalityCallbackFn> unknownLegalityFn;
|
||||
|
||||
/// Flag indicating if unknown operations should be treated as dynamically
|
||||
/// legal.
|
||||
bool unknownOpsDynamicallyLegal;
|
||||
DynamicLegalityCallbackFn unknownLegalityFn;
|
||||
|
||||
/// The current context this target applies to.
|
||||
MLIRContext &ctx;
|
||||
|
|
|
@ -2672,7 +2672,7 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
|
|||
/// Register a legality action for the given operation.
|
||||
void ConversionTarget::setOpAction(OperationName op,
|
||||
LegalizationAction action) {
|
||||
legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None};
|
||||
legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr};
|
||||
}
|
||||
|
||||
/// Register a legality action for the given dialects.
|
||||
|
@ -2703,8 +2703,7 @@ auto ConversionTarget::isLegal(Operation *op) const
|
|||
// Handle dynamic legality either with the provided legality function, or
|
||||
// the default hook on the derived instance.
|
||||
if (info->action == LegalizationAction::Dynamic)
|
||||
return info->legalityFn ? (*info->legalityFn)(op)
|
||||
: isDynamicallyLegal(op);
|
||||
return info->legalityFn ? info->legalityFn(op) : isDynamicallyLegal(op);
|
||||
|
||||
// Otherwise, the operation is only legal if it was marked 'Legal'.
|
||||
return info->action == LegalizationAction::Legal;
|
||||
|
@ -2758,6 +2757,13 @@ void ConversionTarget::setLegalityCallback(
|
|||
dialectLegalityFns[dialect] = callback;
|
||||
}
|
||||
|
||||
/// Set the dynamic legality callback for the unknown ops.
|
||||
void ConversionTarget::setLegalityCallback(
|
||||
const DynamicLegalityCallbackFn &callback) {
|
||||
assert(callback && "expected valid legality callback");
|
||||
unknownLegalityFn = callback;
|
||||
}
|
||||
|
||||
/// Get the legalization information for the given operation.
|
||||
auto ConversionTarget::getOpInfo(OperationName op) const
|
||||
-> Optional<LegalizationInfo> {
|
||||
|
@ -2768,7 +2774,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
|
|||
// Check for info for the parent dialect.
|
||||
auto dialectIt = legalDialects.find(op.getDialectNamespace());
|
||||
if (dialectIt != legalDialects.end()) {
|
||||
Optional<DynamicLegalityCallbackFn> callback;
|
||||
DynamicLegalityCallbackFn callback;
|
||||
auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
|
||||
if (dialectFn != dialectLegalityFns.end())
|
||||
callback = dialectFn->second;
|
||||
|
@ -2776,7 +2782,7 @@ auto ConversionTarget::getOpInfo(OperationName op) const
|
|||
callback};
|
||||
}
|
||||
// Otherwise, check if we mark unknown operations as dynamic.
|
||||
if (unknownOpsDynamicallyLegal)
|
||||
if (unknownLegalityFn)
|
||||
return LegalizationInfo{LegalizationAction::Dynamic,
|
||||
/*isRecursivelyLegal=*/false, unknownLegalityFn};
|
||||
return llvm::None;
|
||||
|
|
Loading…
Reference in New Issue