[mlir][bufferization][NFC] Move OpFilter out of BufferizationOptions

Differential Revision: https://reviews.llvm.org/D126568
This commit is contained in:
Matthias Springer 2022-05-28 01:45:55 +02:00
parent 4a36813669
commit 1534177f8f
8 changed files with 178 additions and 169 deletions

View File

@ -23,6 +23,144 @@ class AnalysisState;
class BufferizableOpInterface;
struct DialectAnalysisState;
class OpFilter {
public:
/// An op filter entry. Filters can be used to specify which ops should be
/// processed by the bufferization.
struct Entry {
/// If the filter function evaluates to `true`, the filter matches.
using FilterFn = std::function<bool(Operation *)>;
/// Filter type: A filter can either be a DENY filter or an ALLOW filter.
enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
FilterFn fn;
FilterType type;
};
/// Return whether the op is allowed or not.
///
/// If the filter does not have an ALLOW rule, ops are allowed by default,
/// unless they are explicitly marked as DENY. If the filter has at least one
/// ALLOW rule, ops are denied by default and only allowed if they match
/// an ALLOW rule and no DENY rule.
bool isOpAllowed(Operation *op) const;
/// Allow the given dialects.
///
/// This function adds one or multiple ALLOW entries.
template <typename... DialectTs> void allowDialect() {
// The following expands a call to allowDialectImpl for each dialect
// in 'DialectTs'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (allowDialectImpl<DialectTs>(), 0)...};
}
/// Deny the given dialects.
///
/// This function adds one or multiple DENY entries.
template <typename... DialectTs> void denyDialect() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
}
/// Allow the given dialect.
///
/// This function adds an ALLOW entry.
void allowDialect(StringRef dialectNamespace) {
Entry::FilterFn filterFn = [=](Operation *op) {
return op->getDialect()->getNamespace() == dialectNamespace;
};
entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
}
/// Allow the given ops.
///
/// This function adds one or multiple ALLOW entries.
template <typename... OpTys> void allowOperation() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
}
/// Deny the given ops.
///
/// This function adds one or multiple DENY entries.
template <typename... OpTys> void denyOperation() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
}
/// Allow the given op.
///
/// This function adds an ALLOW entry.
void allowOperation(StringRef opName) {
Entry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
allowOperation(filterFn);
}
/// Deny the given op.
///
/// This function adds a DENY entry.
void denyOperation(StringRef opName) {
Entry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
denyOperation(filterFn);
}
/// Allow ops that are matched by `fn`.
///
/// This function adds an ALLOW entry.
void allowOperation(Entry::FilterFn fn) {
entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
}
/// Deny ops that are matched by `fn`.
///
/// This function adds a DENY entry.
void denyOperation(Entry::FilterFn fn) {
entries.push_back(Entry{fn, Entry::FilterType::DENY});
}
private:
/// Return `true` if the filter has at least one ALLOW rule.
bool hasAllowRule() const {
for (const Entry &e : entries)
if (e.type == Entry::FilterType::ALLOW)
return true;
return false;
}
/// Allow a dialect.
template <typename DialectT> void allowDialectImpl() {
allowDialect(DialectT::getDialectNamespace());
}
/// Deny a dialect.
template <typename DialectT> void denyDialectImpl() {
denyDialect(DialectT::getDialectNamespace());
}
/// Allow an op.
template <typename OpTy> void allowOperationImpl() {
allowOperation(OpTy::getOperationName());
}
/// Deny an op.
template <typename OpTy> void denyOperationImpl() {
denyOperation(OpTy::getOperationName());
}
/// A list of filter entries that determine whether an op should be allowed or
/// denied. If the filter has an ALLOW rule, only ops that are allowed and not
/// denied are allowed. If the filter does not have an ALLOW rule, only ops
/// that are not denied are allowed.
SmallVector<Entry> entries;
};
/// Options for BufferizableOpInterface-based bufferization.
struct BufferizationOptions {
/// Allocator function: Generate a memref allocation with the given type,
@ -42,19 +180,6 @@ struct BufferizationOptions {
using DialectStateInitFn =
std::function<std::unique_ptr<DialectAnalysisState>()>;
/// An op filter entry. Filters can be used to specify which ops should be
/// processed by the bufferization.
struct OpFilterEntry {
/// If the filter function evaluates to `true`, the filter matches.
using FilterFn = std::function<bool(Operation *)>;
/// Filter type: A filter can either be a DENY filter or an ALLOW filter.
enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
FilterFn fn;
FilterType type;
};
enum class LayoutMapOption : int8_t {
InferLayoutMap = 0,
IdentityLayoutMap = 1,
@ -63,108 +188,6 @@ struct BufferizationOptions {
BufferizationOptions();
/// Return `true` if the filter has at least one ALLOW rule.
bool filterHasAllowRule() const {
for (const OpFilterEntry &e : opFilter)
if (e.type == OpFilterEntry::FilterType::ALLOW)
return true;
return false;
}
/// Return whether the op should be bufferized or not.
///
/// If the filter does not have an ALLOW rule, ops are bufferized by default,
/// unless they are explicitly marked as DENY. If the filter has at least one
/// ALLOW rule, ops are ignored by default and only bufferized if they match
/// an ALLOW rule and no DENY rule.
bool isOpAllowed(Operation *op) const;
/// Allow the given dialects in the filter.
///
/// This function adds one or multiple ALLOW filters.
template <typename... DialectTs>
void allowDialectInFilter() {
// The following expands a call to allowDialectInFilterImpl for each dialect
// in 'DialectTs'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
}
/// Deny the given dialects in the filter.
///
/// This function adds one or multiple DENY filters.
template <typename... DialectTs> void denyDialectInFilter() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
0, (denyDialectInFilterImpl<DialectTs>(), 0)...};
}
/// Allow the given dialect in the filter.
///
/// This function adds an ALLOW filter.
void allowDialectInFilter(StringRef dialectNamespace) {
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
return op->getDialect()->getNamespace() == dialectNamespace;
};
opFilter.push_back(
OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
}
/// Allow the given ops in the filter.
///
/// This function adds one or multiple ALLOW filters.
template <typename... OpTys>
void allowOperationInFilter() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
0, (allowOperationInFilterImpl<OpTys>(), 0)...};
}
/// Deny the given ops in the filter.
///
/// This function adds one or multiple DENY filters.
template <typename... OpTys> void denyOperationInFilter() {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
(void)std::initializer_list<int>{
0, (denyOperationInFilterImpl<OpTys>(), 0)...};
}
/// Allow the given op in the filter.
///
/// This function adds an ALLOW filter.
void allowOperationInFilter(StringRef opName) {
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
allowOperationInFilter(filterFn);
}
/// Deny the given op in the filter.
///
/// This function adds a DENY filter.
void denyOperationInFilter(StringRef opName) {
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
return op->getName().getStringRef() == opName;
};
denyOperationInFilter(filterFn);
}
/// Allow ops that are matched by `fn` in the filter.
///
/// This function adds an ALLOW filter.
void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
}
/// Deny ops that are matched by `fn` in the filter.
///
/// This function adds a DENY filter.
void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
}
/// Try to cast the given op to BufferizableOpInterface if the op is allow
/// listed.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
@ -173,6 +196,13 @@ struct BufferizationOptions {
/// listed.
BufferizableOpInterface dynCastBufferizableOp(Value value) const;
/// A filter that specifies which ops should be bufferized and which ops
/// should be ignored.
OpFilter opFilter;
/// Return `true` if the given op should be bufferized.
bool isOpAllowed(Operation *op) const;
/// Helper functions for allocation, deallocation, memory copying.
Optional<AllocationFn> allocationFn;
Optional<DeallocationFn> deallocationFn;
@ -276,12 +306,6 @@ struct BufferizationOptions {
/// Buffer alignment for new memory allocations.
unsigned int bufferAlignment = 128;
/// A list of op filters that determine whether an op should be processed or
/// ignored by the bufferization. If the filter has an ALLOW rule, only ops
/// that are allowed and not denied are bufferized. If the filter does not
/// have an ALLOW rule, only ops that are not denied are bufferized.
SmallVector<OpFilterEntry> opFilter;
/// Initializer functions for analysis state. These can be used to
/// initialize dialect-specific analysis state.
SmallVector<AnalysisStateInitFn> stateInitializers;
@ -289,29 +313,6 @@ struct BufferizationOptions {
/// Add a analysis state initializer that initializes the specified
/// dialect-specific analysis state.
void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
private:
/// Allow a dialect.
template <typename DialectT>
void allowDialectInFilterImpl() {
allowDialectInFilter(DialectT::getDialectNamespace());
}
/// Deny a dialect.
template <typename DialectT> void denyDialectInFilterImpl() {
denyDialectInFilter(DialectT::getDialectNamespace());
}
/// Allow an op.
template <typename OpTy>
void allowOperationInFilterImpl() {
allowOperationInFilter(OpTy::getOperationName());
}
/// Deny an op.
template <typename OpTy> void denyOperationInFilterImpl() {
denyOperationInFilter(OpTy::getOperationName());
}
};
/// Specify fine-grain relationship between buffers to enable more analysis.

View File

@ -32,9 +32,9 @@ struct ArithmeticBufferizePass
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
if (constantOpOnly) {
options.allowOperationInFilter<arith::ConstantOp>();
options.opFilter.allowOperation<arith::ConstantOp>();
} else {
options.allowDialectInFilter<arith::ArithmeticDialect>();
options.opFilter.allowDialect<arith::ArithmeticDialect>();
}
options.bufferAlignment = alignment;

View File

@ -44,6 +44,29 @@ static const char *kBufferAllocationAttr = "bufferization.allocation";
/// Attribute name used to mark allocs that should not be deallocated.
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//
bool OpFilter::isOpAllowed(Operation *op) const {
// All other ops: Allow/disallow according to filter.
bool isAllowed = !hasAllowRule();
for (const Entry &entry : entries) {
bool filterResult = entry.fn(op);
switch (entry.type) {
case Entry::ALLOW:
isAllowed |= filterResult;
break;
case Entry::DENY:
if (filterResult)
// DENY filter matches. This op is no allowed. (Even if other ALLOW
// filters may match.)
return false;
};
}
return isAllowed;
}
//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//
@ -58,22 +81,7 @@ bool BufferizationOptions::isOpAllowed(Operation *op) const {
if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
return false;
// All other ops: Allow/disallow according to filter.
bool isAllowed = !filterHasAllowRule();
for (const OpFilterEntry &entry : opFilter) {
bool filterResult = entry.fn(op);
switch (entry.type) {
case OpFilterEntry::ALLOW:
isAllowed |= filterResult;
break;
case OpFilterEntry::DENY:
if (filterResult)
// DENY filter matches. This op is no allowed. (Even if other ALLOW
// filters may match.)
return false;
};
}
return isAllowed;
return opFilter.isOpAllowed(op);
}
BufferizableOpInterface

View File

@ -194,7 +194,7 @@ struct OneShotBufferizePass
opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
BufferizationOptions::OpFilterEntry::FilterFn filterFn =
OpFilter::Entry::FilterFn filterFn =
[&](Operation *op) {
// Filter may be specified via options.
if (this->dialectFilter.hasValue())
@ -204,7 +204,7 @@ struct OneShotBufferizePass
// No filter specified: All other ops are allowed.
return true;
};
opt.allowOperationInFilter(filterFn);
opt.opFilter.allowOperation(filterFn);
} else {
opt = *options;
}
@ -242,7 +242,7 @@ struct BufferizationBufferizePass
: public BufferizationBufferizeBase<BufferizationBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.allowDialectInFilter<BufferizationDialect>();
options.opFilter.allowDialect<BufferizationDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();

View File

@ -28,7 +28,7 @@ namespace {
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.allowDialectInFilter<linalg::LinalgDialect>();
options.opFilter.allowDialect<linalg::LinalgDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();

View File

@ -23,7 +23,7 @@ namespace {
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.allowDialectInFilter<shape::ShapeDialect>();
options.opFilter.allowDialect<shape::ShapeDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();

View File

@ -30,7 +30,7 @@ namespace {
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.allowDialectInFilter<tensor::TensorDialect>();
options.opFilter.allowDialect<tensor::TensorDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();

View File

@ -27,7 +27,7 @@ namespace {
struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
void runOnOperation() override {
BufferizationOptions options = getPartialBufferizationOptions();
options.allowDialectInFilter<vector::VectorDialect>();
options.opFilter.allowDialect<vector::VectorDialect>();
if (failed(bufferizeOp(getOperation(), options)))
signalPassFailure();