forked from OSchip/llvm-project
[mlir][bufferization][NFC] Move OpFilter out of BufferizationOptions
Differential Revision: https://reviews.llvm.org/D126568
This commit is contained in:
parent
4a36813669
commit
1534177f8f
|
@ -23,6 +23,144 @@ class AnalysisState;
|
|||
class BufferizableOpInterface;
|
||||
struct DialectAnalysisState;
|
||||
|
||||
class OpFilter {
|
||||
public:
|
||||
/// An op filter entry. Filters can be used to specify which ops should be
|
||||
/// processed by the bufferization.
|
||||
struct Entry {
|
||||
/// If the filter function evaluates to `true`, the filter matches.
|
||||
using FilterFn = std::function<bool(Operation *)>;
|
||||
|
||||
/// Filter type: A filter can either be a DENY filter or an ALLOW filter.
|
||||
enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
|
||||
|
||||
FilterFn fn;
|
||||
FilterType type;
|
||||
};
|
||||
|
||||
/// Return whether the op is allowed or not.
|
||||
///
|
||||
/// If the filter does not have an ALLOW rule, ops are allowed by default,
|
||||
/// unless they are explicitly marked as DENY. If the filter has at least one
|
||||
/// ALLOW rule, ops are denied by default and only allowed if they match
|
||||
/// an ALLOW rule and no DENY rule.
|
||||
bool isOpAllowed(Operation *op) const;
|
||||
|
||||
/// Allow the given dialects.
|
||||
///
|
||||
/// This function adds one or multiple ALLOW entries.
|
||||
template <typename... DialectTs> void allowDialect() {
|
||||
// The following expands a call to allowDialectImpl for each dialect
|
||||
// in 'DialectTs'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (allowDialectImpl<DialectTs>(), 0)...};
|
||||
}
|
||||
|
||||
/// Deny the given dialects.
|
||||
///
|
||||
/// This function adds one or multiple DENY entries.
|
||||
template <typename... DialectTs> void denyDialect() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (denyDialectImpl<DialectTs>(), 0)...};
|
||||
}
|
||||
|
||||
/// Allow the given dialect.
|
||||
///
|
||||
/// This function adds an ALLOW entry.
|
||||
void allowDialect(StringRef dialectNamespace) {
|
||||
Entry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getDialect()->getNamespace() == dialectNamespace;
|
||||
};
|
||||
entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
|
||||
}
|
||||
|
||||
/// Allow the given ops.
|
||||
///
|
||||
/// This function adds one or multiple ALLOW entries.
|
||||
template <typename... OpTys> void allowOperation() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (allowOperationImpl<OpTys>(), 0)...};
|
||||
}
|
||||
|
||||
/// Deny the given ops.
|
||||
///
|
||||
/// This function adds one or multiple DENY entries.
|
||||
template <typename... OpTys> void denyOperation() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{0, (denyOperationImpl<OpTys>(), 0)...};
|
||||
}
|
||||
|
||||
/// Allow the given op.
|
||||
///
|
||||
/// This function adds an ALLOW entry.
|
||||
void allowOperation(StringRef opName) {
|
||||
Entry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getName().getStringRef() == opName;
|
||||
};
|
||||
allowOperation(filterFn);
|
||||
}
|
||||
|
||||
/// Deny the given op.
|
||||
///
|
||||
/// This function adds a DENY entry.
|
||||
void denyOperation(StringRef opName) {
|
||||
Entry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getName().getStringRef() == opName;
|
||||
};
|
||||
denyOperation(filterFn);
|
||||
}
|
||||
|
||||
/// Allow ops that are matched by `fn`.
|
||||
///
|
||||
/// This function adds an ALLOW entry.
|
||||
void allowOperation(Entry::FilterFn fn) {
|
||||
entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
|
||||
}
|
||||
|
||||
/// Deny ops that are matched by `fn`.
|
||||
///
|
||||
/// This function adds a DENY entry.
|
||||
void denyOperation(Entry::FilterFn fn) {
|
||||
entries.push_back(Entry{fn, Entry::FilterType::DENY});
|
||||
}
|
||||
|
||||
private:
|
||||
/// Return `true` if the filter has at least one ALLOW rule.
|
||||
bool hasAllowRule() const {
|
||||
for (const Entry &e : entries)
|
||||
if (e.type == Entry::FilterType::ALLOW)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Allow a dialect.
|
||||
template <typename DialectT> void allowDialectImpl() {
|
||||
allowDialect(DialectT::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Deny a dialect.
|
||||
template <typename DialectT> void denyDialectImpl() {
|
||||
denyDialect(DialectT::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Allow an op.
|
||||
template <typename OpTy> void allowOperationImpl() {
|
||||
allowOperation(OpTy::getOperationName());
|
||||
}
|
||||
|
||||
/// Deny an op.
|
||||
template <typename OpTy> void denyOperationImpl() {
|
||||
denyOperation(OpTy::getOperationName());
|
||||
}
|
||||
|
||||
/// A list of filter entries that determine whether an op should be allowed or
|
||||
/// denied. If the filter has an ALLOW rule, only ops that are allowed and not
|
||||
/// denied are allowed. If the filter does not have an ALLOW rule, only ops
|
||||
/// that are not denied are allowed.
|
||||
SmallVector<Entry> entries;
|
||||
};
|
||||
|
||||
/// Options for BufferizableOpInterface-based bufferization.
|
||||
struct BufferizationOptions {
|
||||
/// Allocator function: Generate a memref allocation with the given type,
|
||||
|
@ -42,19 +180,6 @@ struct BufferizationOptions {
|
|||
using DialectStateInitFn =
|
||||
std::function<std::unique_ptr<DialectAnalysisState>()>;
|
||||
|
||||
/// An op filter entry. Filters can be used to specify which ops should be
|
||||
/// processed by the bufferization.
|
||||
struct OpFilterEntry {
|
||||
/// If the filter function evaluates to `true`, the filter matches.
|
||||
using FilterFn = std::function<bool(Operation *)>;
|
||||
|
||||
/// Filter type: A filter can either be a DENY filter or an ALLOW filter.
|
||||
enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
|
||||
|
||||
FilterFn fn;
|
||||
FilterType type;
|
||||
};
|
||||
|
||||
enum class LayoutMapOption : int8_t {
|
||||
InferLayoutMap = 0,
|
||||
IdentityLayoutMap = 1,
|
||||
|
@ -63,108 +188,6 @@ struct BufferizationOptions {
|
|||
|
||||
BufferizationOptions();
|
||||
|
||||
/// Return `true` if the filter has at least one ALLOW rule.
|
||||
bool filterHasAllowRule() const {
|
||||
for (const OpFilterEntry &e : opFilter)
|
||||
if (e.type == OpFilterEntry::FilterType::ALLOW)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Return whether the op should be bufferized or not.
|
||||
///
|
||||
/// If the filter does not have an ALLOW rule, ops are bufferized by default,
|
||||
/// unless they are explicitly marked as DENY. If the filter has at least one
|
||||
/// ALLOW rule, ops are ignored by default and only bufferized if they match
|
||||
/// an ALLOW rule and no DENY rule.
|
||||
bool isOpAllowed(Operation *op) const;
|
||||
|
||||
/// Allow the given dialects in the filter.
|
||||
///
|
||||
/// This function adds one or multiple ALLOW filters.
|
||||
template <typename... DialectTs>
|
||||
void allowDialectInFilter() {
|
||||
// The following expands a call to allowDialectInFilterImpl for each dialect
|
||||
// in 'DialectTs'. This magic is necessary due to a limitation in the places
|
||||
// that a parameter pack can be expanded in c++11.
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{
|
||||
0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
|
||||
}
|
||||
|
||||
/// Deny the given dialects in the filter.
|
||||
///
|
||||
/// This function adds one or multiple DENY filters.
|
||||
template <typename... DialectTs> void denyDialectInFilter() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{
|
||||
0, (denyDialectInFilterImpl<DialectTs>(), 0)...};
|
||||
}
|
||||
|
||||
/// Allow the given dialect in the filter.
|
||||
///
|
||||
/// This function adds an ALLOW filter.
|
||||
void allowDialectInFilter(StringRef dialectNamespace) {
|
||||
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getDialect()->getNamespace() == dialectNamespace;
|
||||
};
|
||||
opFilter.push_back(
|
||||
OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
|
||||
}
|
||||
|
||||
/// Allow the given ops in the filter.
|
||||
///
|
||||
/// This function adds one or multiple ALLOW filters.
|
||||
template <typename... OpTys>
|
||||
void allowOperationInFilter() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{
|
||||
0, (allowOperationInFilterImpl<OpTys>(), 0)...};
|
||||
}
|
||||
|
||||
/// Deny the given ops in the filter.
|
||||
///
|
||||
/// This function adds one or multiple DENY filters.
|
||||
template <typename... OpTys> void denyOperationInFilter() {
|
||||
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
|
||||
(void)std::initializer_list<int>{
|
||||
0, (denyOperationInFilterImpl<OpTys>(), 0)...};
|
||||
}
|
||||
|
||||
/// Allow the given op in the filter.
|
||||
///
|
||||
/// This function adds an ALLOW filter.
|
||||
void allowOperationInFilter(StringRef opName) {
|
||||
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getName().getStringRef() == opName;
|
||||
};
|
||||
allowOperationInFilter(filterFn);
|
||||
}
|
||||
|
||||
/// Deny the given op in the filter.
|
||||
///
|
||||
/// This function adds a DENY filter.
|
||||
void denyOperationInFilter(StringRef opName) {
|
||||
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
|
||||
return op->getName().getStringRef() == opName;
|
||||
};
|
||||
denyOperationInFilter(filterFn);
|
||||
}
|
||||
|
||||
/// Allow ops that are matched by `fn` in the filter.
|
||||
///
|
||||
/// This function adds an ALLOW filter.
|
||||
void allowOperationInFilter(OpFilterEntry::FilterFn fn) {
|
||||
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::ALLOW});
|
||||
}
|
||||
|
||||
/// Deny ops that are matched by `fn` in the filter.
|
||||
///
|
||||
/// This function adds a DENY filter.
|
||||
void denyOperationInFilter(OpFilterEntry::FilterFn fn) {
|
||||
opFilter.push_back(OpFilterEntry{fn, OpFilterEntry::FilterType::DENY});
|
||||
}
|
||||
|
||||
/// Try to cast the given op to BufferizableOpInterface if the op is allow
|
||||
/// listed.
|
||||
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
|
||||
|
@ -173,6 +196,13 @@ struct BufferizationOptions {
|
|||
/// listed.
|
||||
BufferizableOpInterface dynCastBufferizableOp(Value value) const;
|
||||
|
||||
/// A filter that specifies which ops should be bufferized and which ops
|
||||
/// should be ignored.
|
||||
OpFilter opFilter;
|
||||
|
||||
/// Return `true` if the given op should be bufferized.
|
||||
bool isOpAllowed(Operation *op) const;
|
||||
|
||||
/// Helper functions for allocation, deallocation, memory copying.
|
||||
Optional<AllocationFn> allocationFn;
|
||||
Optional<DeallocationFn> deallocationFn;
|
||||
|
@ -276,12 +306,6 @@ struct BufferizationOptions {
|
|||
/// Buffer alignment for new memory allocations.
|
||||
unsigned int bufferAlignment = 128;
|
||||
|
||||
/// A list of op filters that determine whether an op should be processed or
|
||||
/// ignored by the bufferization. If the filter has an ALLOW rule, only ops
|
||||
/// that are allowed and not denied are bufferized. If the filter does not
|
||||
/// have an ALLOW rule, only ops that are not denied are bufferized.
|
||||
SmallVector<OpFilterEntry> opFilter;
|
||||
|
||||
/// Initializer functions for analysis state. These can be used to
|
||||
/// initialize dialect-specific analysis state.
|
||||
SmallVector<AnalysisStateInitFn> stateInitializers;
|
||||
|
@ -289,29 +313,6 @@ struct BufferizationOptions {
|
|||
/// Add a analysis state initializer that initializes the specified
|
||||
/// dialect-specific analysis state.
|
||||
void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
|
||||
|
||||
private:
|
||||
/// Allow a dialect.
|
||||
template <typename DialectT>
|
||||
void allowDialectInFilterImpl() {
|
||||
allowDialectInFilter(DialectT::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Deny a dialect.
|
||||
template <typename DialectT> void denyDialectInFilterImpl() {
|
||||
denyDialectInFilter(DialectT::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Allow an op.
|
||||
template <typename OpTy>
|
||||
void allowOperationInFilterImpl() {
|
||||
allowOperationInFilter(OpTy::getOperationName());
|
||||
}
|
||||
|
||||
/// Deny an op.
|
||||
template <typename OpTy> void denyOperationInFilterImpl() {
|
||||
denyOperationInFilter(OpTy::getOperationName());
|
||||
}
|
||||
};
|
||||
|
||||
/// Specify fine-grain relationship between buffers to enable more analysis.
|
||||
|
|
|
@ -32,9 +32,9 @@ struct ArithmeticBufferizePass
|
|||
void runOnOperation() override {
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
if (constantOpOnly) {
|
||||
options.allowOperationInFilter<arith::ConstantOp>();
|
||||
options.opFilter.allowOperation<arith::ConstantOp>();
|
||||
} else {
|
||||
options.allowDialectInFilter<arith::ArithmeticDialect>();
|
||||
options.opFilter.allowDialect<arith::ArithmeticDialect>();
|
||||
}
|
||||
options.bufferAlignment = alignment;
|
||||
|
||||
|
|
|
@ -44,6 +44,29 @@ static const char *kBufferAllocationAttr = "bufferization.allocation";
|
|||
/// Attribute name used to mark allocs that should not be deallocated.
|
||||
static const char *kSkipDeallocAttr = "bufferization.skip_dealloc";
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpFilter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool OpFilter::isOpAllowed(Operation *op) const {
|
||||
// All other ops: Allow/disallow according to filter.
|
||||
bool isAllowed = !hasAllowRule();
|
||||
for (const Entry &entry : entries) {
|
||||
bool filterResult = entry.fn(op);
|
||||
switch (entry.type) {
|
||||
case Entry::ALLOW:
|
||||
isAllowed |= filterResult;
|
||||
break;
|
||||
case Entry::DENY:
|
||||
if (filterResult)
|
||||
// DENY filter matches. This op is no allowed. (Even if other ALLOW
|
||||
// filters may match.)
|
||||
return false;
|
||||
};
|
||||
}
|
||||
return isAllowed;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferizationOptions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -58,22 +81,7 @@ bool BufferizationOptions::isOpAllowed(Operation *op) const {
|
|||
if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
|
||||
return false;
|
||||
|
||||
// All other ops: Allow/disallow according to filter.
|
||||
bool isAllowed = !filterHasAllowRule();
|
||||
for (const OpFilterEntry &entry : opFilter) {
|
||||
bool filterResult = entry.fn(op);
|
||||
switch (entry.type) {
|
||||
case OpFilterEntry::ALLOW:
|
||||
isAllowed |= filterResult;
|
||||
break;
|
||||
case OpFilterEntry::DENY:
|
||||
if (filterResult)
|
||||
// DENY filter matches. This op is no allowed. (Even if other ALLOW
|
||||
// filters may match.)
|
||||
return false;
|
||||
};
|
||||
}
|
||||
return isAllowed;
|
||||
return opFilter.isOpAllowed(op);
|
||||
}
|
||||
|
||||
BufferizableOpInterface
|
||||
|
|
|
@ -194,7 +194,7 @@ struct OneShotBufferizePass
|
|||
opt.promoteBufferResultsToOutParams = promoteBufferResultsToOutParams;
|
||||
opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
|
||||
|
||||
BufferizationOptions::OpFilterEntry::FilterFn filterFn =
|
||||
OpFilter::Entry::FilterFn filterFn =
|
||||
[&](Operation *op) {
|
||||
// Filter may be specified via options.
|
||||
if (this->dialectFilter.hasValue())
|
||||
|
@ -204,7 +204,7 @@ struct OneShotBufferizePass
|
|||
// No filter specified: All other ops are allowed.
|
||||
return true;
|
||||
};
|
||||
opt.allowOperationInFilter(filterFn);
|
||||
opt.opFilter.allowOperation(filterFn);
|
||||
} else {
|
||||
opt = *options;
|
||||
}
|
||||
|
@ -242,7 +242,7 @@ struct BufferizationBufferizePass
|
|||
: public BufferizationBufferizeBase<BufferizationBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.allowDialectInFilter<BufferizationDialect>();
|
||||
options.opFilter.allowDialect<BufferizationDialect>();
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace {
|
|||
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.allowDialectInFilter<linalg::LinalgDialect>();
|
||||
options.opFilter.allowDialect<linalg::LinalgDialect>();
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -23,7 +23,7 @@ namespace {
|
|||
struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.allowDialectInFilter<shape::ShapeDialect>();
|
||||
options.opFilter.allowDialect<shape::ShapeDialect>();
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace {
|
|||
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.allowDialectInFilter<tensor::TensorDialect>();
|
||||
options.opFilter.allowDialect<tensor::TensorDialect>();
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -27,7 +27,7 @@ namespace {
|
|||
struct VectorBufferizePass : public VectorBufferizeBase<VectorBufferizePass> {
|
||||
void runOnOperation() override {
|
||||
BufferizationOptions options = getPartialBufferizationOptions();
|
||||
options.allowDialectInFilter<vector::VectorDialect>();
|
||||
options.opFilter.allowDialect<vector::VectorDialect>();
|
||||
|
||||
if (failed(bufferizeOp(getOperation(), options)))
|
||||
signalPassFailure();
|
||||
|
|
Loading…
Reference in New Issue