[mlir][bufferize][NFC] Remove BufferizationState

With the recent refactorings, this class is no longer needed. We can use BufferizationOptions in all places where BufferizationState was used.
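
The change is mechanical throughout: `bufferize` implementations take the options directly, and the former `BufferizationState` helper methods become free functions. A minimal before/after sketch, condensed from the diffs below (the operand lookup is illustrative, not taken from the patch):

```c++
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

// Before this commit: buffer lookup went through the state object, which
// merely wrapped a reference to the options.
LogicalResult bufferizeBefore(Operation *op, RewriterBase &rewriter,
                              BufferizationState &state) {
  Value buffer = state.getBuffer(rewriter, op->getOperand(0));
  replaceOpWithBufferizedValues(rewriter, op, buffer);
  return success();
}

// After this commit: the options are passed directly and getBuffer() is a
// free function in the bufferization namespace.
LogicalResult bufferizeAfter(Operation *op, RewriterBase &rewriter,
                             const BufferizationOptions &options) {
  Value buffer = getBuffer(rewriter, op->getOperand(0), options);
  replaceOpWithBufferizedValues(rewriter, op, buffer);
  return success();
}
```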

Differential Revision: https://reviews.llvm.org/D127653
Matthias Springer 2022-06-17 14:01:25 +02:00
parent c80c57674e
commit b55d55ecd9
17 changed files with 143 additions and 173 deletions


@@ -30,40 +30,38 @@ and with aggressive in-place bufferization.
One-Shot Bufferize is:
-* **Monolithic**: A single MLIR pass does the entire
-work, whereas the previous bufferization in MLIR was split across multiple
-passes residing in different dialects. In One-Shot Bufferize,
-`BufferizableOpInterface` implementations are spread across different dialects.
+* **Monolithic**: A single MLIR pass does the entire work, whereas the
+  previous bufferization in MLIR was split across multiple passes residing in
+  different dialects. In One-Shot Bufferize, `BufferizableOpInterface`
+  implementations are spread across different dialects.
-* A **whole-function at a time analysis**. In-place bufferization decisions are
-made by analyzing SSA use-def chains on tensors. Op interface implementations
-not only provide the rewrite logic from tensor ops to memref ops, but also
-helper methods for One-Shot Bufferize's analysis to query information about an
-op's bufferization/memory semantics.
+* A **whole-function at a time analysis**. In-place bufferization decisions
+  are made by analyzing SSA use-def chains on tensors. Op interface
+  implementations not only provide the rewrite logic from tensor ops to memref
+  ops, but also helper methods for One-Shot Bufferize's analysis to query
+  information about an op's bufferization/memory semantics.
-* **Extensible** via an op interface: All
-ops that implement `BufferizableOpInterface` can be bufferized.
+* **Extensible** via an op interface: All ops that implement
+  `BufferizableOpInterface` can be bufferized.
-* **2-Pass**:
-Bufferization is internally broken down into 2 steps: First, analyze the entire
-IR and make bufferization decisions. Then, bufferize (rewrite) the IR. The
-analysis has access to exact SSA use-def information. It incrementally builds
-alias and equivalence sets and does not rely on a posteriori-alias analysis from
-preallocated memory.
+* **2-Pass**: Bufferization is internally broken down into 2 steps: First,
+  analyze the entire IR and make bufferization decisions. Then, bufferize
+  (rewrite) the IR. The analysis has access to exact SSA use-def information.
+  It incrementally builds alias and equivalence sets and does not rely on a
+  posteriori-alias analysis from preallocated memory.
-* **Greedy**: Operations are analyzed one-by-one and it is
-decided on the spot whether a tensor OpOperand must be copied or not. Heuristics
-determine the order of analysis.
+* **Greedy**: Operations are analyzed one-by-one and it is decided on the spot
+  whether a tensor OpOperand must be copied or not. Heuristics determine the
+  order of analysis.
-* **Modular**: The current One-Shot Analysis
-can be replaced with a different analysis. The results of the analysis are
-queried by the bufferization via `BufferizationState`, in particular
-`BufferizationState::isInPlace`. Any derived class of `BufferizationState` that
-implements a small number of virtual functions can serve as a custom analysis.
-It is even possible to run One-Shot Bufferize without any analysis
-(`AlwaysCopyBufferizationState`), in which case One-Shot Bufferize behaves
-exactly like the old dialect conversion-based bufferization (i.e., copy every
-buffer before writing to it).
+* **Modular**: The current One-Shot Analysis can be replaced with a different
+  analysis. The results of the analysis are queried by the bufferization via
+  `AnalysisState`, in particular `AnalysisState::isInPlace`. Any derived class
+  of `AnalysisState` that implements a small number of virtual functions can
+  serve as a custom analysis. It is even possible to run One-Shot Bufferize
+  without any analysis (`AlwaysCopyAnalysisState`), in which case One-Shot
+  Bufferize behaves exactly like the old dialect conversion-based
+  bufferization (i.e., copy every buffer before writing to it).
To reduce complexity, One-Shot Bufferize should be
[run after other transformations](https://llvm.discourse.group/t/rfc-linalg-on-tensors-update-and-comprehensive-bufferization-rfc/3373),
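
As a rough sketch of the "Modular" point above: a custom analysis derives from `AnalysisState` and overrides its query hooks. This assumes `isInPlace` is one of the virtual functions the docs refer to; the class name here is hypothetical:

```c++
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical analysis in the spirit of AlwaysCopyAnalysisState: every
// tensor OpOperand is treated as out-of-place, so a copy is inserted before
// each write, mimicking the old dialect conversion-based behavior.
// (Assumes isInPlace is a virtual hook on AnalysisState, per the docs above.)
struct CopyEverythingState : public AnalysisState {
  using AnalysisState::AnalysisState;

  bool isInPlace(OpOperand &opOperand) const override { return false; }
};
```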


@@ -236,7 +236,7 @@ struct BufferizationOptions {
  ///
  /// Note: Deactivating this flag can lead to incorrect bufferization results
  /// when used incorrectly. This flag is useful with
-  /// `AlwaysCopyBufferizationState` which bufferizes all writing tensor
+  /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor
  /// OpOperands out-of-place.
  bool enforceAliasingInvariants = true;
@@ -464,33 +464,6 @@ private:
  const BufferizationOptions &options;
};
-/// BufferizationState provides helper functions for performing bufferization
-/// rewrites and handling memref buffers.
-struct BufferizationState {
-  BufferizationState(const BufferizationOptions &options) : options(options) {}
-  /// Lookup the buffer for the given value. If the value was not bufferized
-  /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
-  /// from which the memref operand is returned.
-  Value getBuffer(RewriterBase &rewriter, Value value);
-  /// Return the buffer type for a given Value (tensor) after bufferization.
-  ///
-  /// Note: Op implementations should preferably call `getBuffer()->getType()`.
-  /// This function should only be used if `getBuffer` cannot be used.
-  BaseMemRefType getBufferType(Value value) const;
-  /// Return a reference to the BufferizationOptions.
-  const BufferizationOptions &getOptions() const { return options; }
-protected:
-  // BufferizationState should be passed as a reference.
-  BufferizationState(const BufferizationState &) = delete;
-private:
-  const BufferizationOptions &options;
-};
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
@@ -498,6 +471,18 @@ Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                   Value shapedValue, bool escape,
                                   bool copy = true);
+/// Lookup the buffer for the given value. If the value was not bufferized
+/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
+/// from which the memref operand is returned.
+Value getBuffer(RewriterBase &rewriter, Value value,
+                const BufferizationOptions &options);
+/// Return the buffer type for a given Value (tensor) after bufferization.
+///
+/// Note: Op implementations should preferably call `getBuffer()->getType()`.
+/// This function should only be used if `getBuffer` cannot be used.
+BaseMemRefType getBufferType(Value value, const BufferizationOptions &options);
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,


@@ -221,7 +221,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
    InterfaceMethod<
      /*desc=*/[{
        Bufferize this op, i.e., rewrite it into a memref-based equivalent.
-        Buffers of tensor SSA values can be retrieved via `state.getBuffer`.
+        Buffers of tensor SSA values can be retrieved via `getBuffer`.
        Uses of tensor results of the existing tensor op can be replaced with
        `replaceOpWithBufferizedValues` or `replaceOpWithNewBufferizedOp`.
        These two functions automatically handle the tensor-to-memref type
@@ -233,12 +233,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
        a) A buffer that aliases one of buffers in getAliasingOpOperand(r).
        b) Or: A newly allocated buffer.
-        Regions of an op should be inlined into the new op instead of cloning
-        them. This is not only more efficient, but also necessary so that no
-        analysis results are lost. (Bufferization decisions are tracked via
-        OpOperand pointers and cloned ops have new OpOperands.) If regions are
-        cloned instead of inlined, additional buffer copies may be inserted.
        This method will never be called on ops that do not have at least one
        tensor operand/result.
@@ -252,7 +246,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
      /*retType=*/"LogicalResult",
      /*methodName=*/"bufferize",
      /*args=*/(ins "RewriterBase &":$rewriter,
-                    "BufferizationState &":$state),
+                    "const BufferizationOptions &":$options),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        llvm_unreachable("bufferize not implemented");
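
For reference, the updated signature above is what the per-dialect external models in the rest of this commit implement. A minimal sketch for a hypothetical single-tensor-operand op (`MyOp` is a placeholder, not part of the patch):

```c++
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

// Sketch: an external model implementing the new bufferize() signature for
// a placeholder op type MyOp (assumed to be defined elsewhere).
struct MyOpInterface
    : public BufferizableOpInterface::ExternalModel<MyOpInterface, MyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // Retrieve the buffer of the tensor operand via the new free function.
    Value buffer = getBuffer(rewriter, op->getOperand(0), options);
    // Tensor results must be replaced with memref values.
    replaceOpWithBufferizedValues(rewriter, op, buffer);
    return success();
  }
};
```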


@@ -71,7 +71,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
  let results = (outs AnyTensor:$result);
  let extraClassDeclaration = [{
-    LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);
+    LogicalResult bufferize(RewriterBase &rewriter,
+                            const BufferizationOptions &options);
    bool isMemoryWrite(OpResult opResult, const AnalysisState &state);
@@ -242,7 +243,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
    // results as not writable enforces a buffer copy and has the same effect.
    LogicalResult bufferize(RewriterBase &rewriter,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
      // to_tensor cannot be bufferized. However, other ops that are using
      // to_tensor's result will eventually be bufferized. At that point, they
      // will start using to_tensor's memref operand. Once all users of
@@ -334,7 +335,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
    }
    LogicalResult bufferize(RewriterBase &rewriter,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
  }];
  let assemblyFormat = "$tensor attr-dict `:` type($memref)";


@@ -25,7 +25,6 @@ namespace mlir {
namespace bufferization {
class AnalysisState;
-struct BufferizationState;
struct BufferizationOptions;
class OpFilter;


@@ -15,7 +15,6 @@ struct LogicalResult;
class ModuleOp;
namespace bufferization {
-struct BufferizationState;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;


@@ -23,7 +23,7 @@ struct ConstantOpInterface
    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                    arith::ConstantOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto constantOp = cast<arith::ConstantOp>(op);
    // Only ranked tensors are supported.
@@ -38,7 +38,7 @@ struct ConstantOpInterface
    // Create global memory segment and replace tensor with memref pointing to
    // that memory segment.
    FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getOptions().bufferAlignment);
+        getGlobalFor(constantOp, options.bufferAlignment);
    if (failed(globalOp))
      return failure();
    memref::GlobalOp globalMemref = globalOp.getValue();
@@ -80,11 +80,11 @@ struct IndexCastOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto castOp = cast<arith::IndexCastOp>(op);
    auto resultTensorType = castOp.getType().cast<TensorType>();
-    Value source = state.getBuffer(rewriter, castOp.getIn());
+    Value source = getBuffer(rewriter, castOp.getIn(), options);
    auto sourceType = source.getType().cast<BaseMemRefType>();
    // Result type should have same layout and address space as the source type.
@@ -132,7 +132,7 @@ struct SelectOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto selectOp = cast<arith::SelectOp>(op);
    Location loc = selectOp.getLoc();
@@ -140,8 +140,8 @@ struct SelectOpInterface
    // instead of its OpOperands. In the worst case, 2 copies are inserted at
    // the moment (one for each tensor). When copying the op result, only one
    // copy would be needed.
-    Value trueBuffer = state.getBuffer(rewriter, selectOp.getTrueValue());
-    Value falseBuffer = state.getBuffer(rewriter, selectOp.getFalseValue());
+    Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options);
+    Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options);
    // The "true" and the "false" operands must have the same type. If the
    // buffers have different types, they differ only in their layout map. Cast


@@ -477,7 +477,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#endif
}
-Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
+Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
+                               const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
@@ -488,21 +489,22 @@ Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
-  Type memrefType = getMemRefType(tensorType, getOptions());
+  Type memrefType = getMemRefType(tensorType, options);
  ensureToMemrefOpIsValid(value, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
                                                    value);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-BaseMemRefType BufferizationState::getBufferType(Value value) const {
+BaseMemRefType
+bufferization::getBufferType(Value value, const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.memref().getType().cast<BaseMemRefType>();
-  return getMemRefType(tensorType, getOptions());
+  return getMemRefType(tensorType, options);
}
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,


@@ -150,7 +150,7 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = getLoc();
@@ -163,7 +163,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
  // Create buffer allocation.
  Value copyBuffer;
  if (copy())
-    copyBuffer = state.getBuffer(rewriter, copy());
+    copyBuffer = getBuffer(rewriter, copy(), options);
  auto allocType =
      MemRefType::get(getType().getShape(), getType().getElementType());
  SmallVector<Value> dynamicDims = dynamicSizes();
@@ -172,25 +172,24 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
    populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
  }
  FailureOr<Value> alloc =
-      state.getOptions().createAlloc(rewriter, loc, allocType, dynamicDims);
+      options.createAlloc(rewriter, loc, allocType, dynamicDims);
  if (failed(alloc))
    return failure();
  // Create memory copy (if any).
  if (copy()) {
-    if (failed(
-            state.getOptions().createMemCpy(rewriter, loc, copyBuffer, *alloc)))
+    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
      return failure();
  }
  // Should the buffer be deallocated?
-  AnalysisState analysisState(state.getOptions());
+  AnalysisState analysisState(options);
  bool dealloc;
  if (escape().hasValue()) {
    dealloc = !*escape();
  } else {
    // No "escape" annotation found.
-    if (state.getOptions().createDeallocs) {
+    if (options.createDeallocs) {
      // Perform an ad-hoc analysis.
      dealloc = !analysisState.isTensorYielded(getResult());
    } else {
@@ -206,7 +205,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
    return success();
  rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
-  if (failed(state.getOptions().createDealloc(rewriter, loc, *alloc)))
+  if (failed(options.createDealloc(rewriter, loc, *alloc)))
    return failure();
  return success();
}
@@ -627,7 +626,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
  // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
  (void)foldToMemrefToTensorPair(rewriter, *this);
  // Note: The return value of `bufferize` indicates whether there was an error


@@ -401,7 +401,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
  DenseSet<Operation *> erasedOps;
  // Bufferize all ops.
-  BufferizationState bufferizationState(options);
  BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                 worklist, options, opFilter);
  for (unsigned i = 0; i < worklist.size(); ++i) {
@@ -420,7 +419,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
      continue;
    // Bufferize the op.
    rewriter.setInsertionPoint(op);
-    if (failed(bufferizableOp.bufferize(rewriter, bufferizationState)))
+    if (failed(bufferizableOp.bufferize(rewriter, options)))
      return op->emitError("failed to bufferize op");
  }
@@ -433,7 +432,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
  /// Check the result of bufferization. Return an error if an op was not
  /// bufferized, unless partial bufferization is allowed.
-  if (bufferizationState.getOptions().allowUnknownOps)
+  if (options.allowUnknownOps)
    return success();
  for (Operation *op : worklist) {


@@ -258,7 +258,7 @@ struct CallOpInterface
  /// All function arguments are writable. It is the responsibility of the
  /// CallOp to insert buffer copies where necessary.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    func::CallOp callOp = cast<func::CallOp>(op);
    unsigned numResults = callOp.getNumResults();
    unsigned numOperands = callOp->getNumOperands();
@@ -307,7 +307,7 @@ struct CallOpInterface
      // Retrieve buffers for tensor operands.
      Value buffer = newOperands[idx];
      if (!buffer)
-        buffer = state.getBuffer(rewriter, opOperand.get());
+        buffer = getBuffer(rewriter, opOperand.get(), options);
      // Caller / callee type mismatch is handled with a CastOp.
      auto memRefType = funcType.getInput(idx);
@@ -364,7 +364,7 @@ struct ReturnOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
#ifndef NDEBUG
    auto returnOp = cast<func::ReturnOp>(op);
    assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -386,11 +386,9 @@ struct FuncOpInterface
  /// All function bbArgs are writable unless they are explicitly marked as
  /// read-only. Callers must insert copies when needed.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto funcOp = cast<FuncOp>(op);
    FunctionType funcType = funcOp.getFunctionType();
-    const OneShotBufferizationOptions &options =
-        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
    // Construct the bufferized function type.
    SmallVector<Type> argTypes;


@@ -429,7 +429,6 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  IRRewriter rewriter(moduleOp.getContext());
-  BufferizationState bufferizationState(options);
  // A list of functions in the order in which they are analyzed + bufferized.
  SmallVector<func::FuncOp> orderedFuncOps;


@@ -20,11 +20,9 @@ using namespace mlir::bufferization;
namespace {
// TODO: Ops in the linalg dialect can directly implement this interface.
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
  // Take a guard before anything else.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
@@ -46,14 +44,14 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
      newInputBuffers.push_back(opOperand->get());
      continue;
    }
-    newInputBuffers.push_back(state.getBuffer(rewriter, opOperand->get()));
+    newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options));
  }
  // New output operands for the cloned op.
  SmallVector<Value> newOutputBuffers;
  for (OpResult opResult : op->getOpResults()) {
    OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
-    Value resultBuffer = state.getBuffer(rewriter, opOperand->get());
+    Value resultBuffer = getBuffer(rewriter, opOperand->get(), options);
    newOutputBuffers.push_back(resultBuffer);
  }
@@ -123,8 +121,8 @@ struct LinalgOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
-    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
+                          const BufferizationOptions &options) const {
+    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options);
  }
};


@@ -73,7 +73,7 @@ struct ExecuteRegionOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
    // Compute new result types.
@@ -81,7 +81,7 @@ struct ExecuteRegionOpInterface
    for (Type type : executeRegionOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newResultTypes.push_back(getMemRefType(tensorType, options));
      } else {
        newResultTypes.push_back(type);
      }
@@ -183,7 +183,7 @@ struct IfOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto ifOp = cast<scf::IfOp>(op);
    // Compute new types of the bufferized scf.if op.
@@ -191,7 +191,7 @@ struct IfOpInterface
    for (Type returnType : ifOp->getResultTypes()) {
      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
-        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newTypes.push_back(getMemRefType(tensorType, options));
      } else {
        newTypes.push_back(returnType);
      }
@@ -309,11 +309,11 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                     MutableArrayRef<OpOperand> operands,
-                                     BufferizationState &state) {
+                                     const BufferizationOptions &options) {
  SmallVector<Value> result;
  for (OpOperand &opOperand : operands) {
    if (opOperand.get().getType().isa<TensorType>()) {
-      Value resultBuffer = state.getBuffer(rewriter, opOperand.get());
+      Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
      result.push_back(resultBuffer);
    } else {
      result.push_back(opOperand.get());
@@ -325,10 +325,11 @@ static SmallVector<Value> getBuffers(RewriterBase &rewriter,
/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
-                              BaseMemRefType type, BufferizationState &state) {
+                              BaseMemRefType type,
+                              const BufferizationOptions &options) {
  assert(tensor.getType().isa<TensorType>() && "expected tensor");
  ensureToMemrefOpIsValid(tensor, type);
-  Value yieldedVal = state.getBuffer(rewriter, tensor);
+  Value yieldedVal = getBuffer(rewriter, tensor, options);
  return castBuffer(rewriter, yieldedVal, type);
}
@@ -352,12 +353,12 @@ convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                    TypeRange bufferizedTypes,
                                    const DenseSet<int64_t> &tensorIndices,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
  return convertTensorValues(
      values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
-                                state);
+                                options);
      });
}
@@ -472,7 +473,7 @@ struct ForOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto forOp = cast<scf::ForOp>(op);
    Block *oldLoopBody = &forOp.getLoopBody().front();
@@ -482,7 +483,7 @@ struct ForOpInterface
    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
-        getBuffers(rewriter, forOp.getIterOpOperands(), state);
+        getBuffers(rewriter, forOp.getIterOpOperands(), options);
    // Construct a new scf.for op with memref instead of tensor values.
    auto newForOp = rewriter.create<scf::ForOp>(
@@ -511,7 +512,7 @@ struct ForOpInterface
    auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> yieldValues = getYieldedValues(
-        rewriter, yieldOp.getResults(), initArgsTypes, indices, state);
+        rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
    yieldOp.getResultsMutable().assign(yieldValues);
    // Replace loop results.
@@ -704,7 +705,7 @@ struct WhileOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto whileOp = cast<scf::WhileOp>(op);
    assert(whileOp.getBefore().getBlocks().size() == 1 &&
@@ -722,12 +723,12 @@ struct WhileOpInterface
    // The new memref init_args of the loop.
    SmallVector<Value> initArgs =
-        getBuffers(rewriter, whileOp->getOpOperands(), state);
+        getBuffers(rewriter, whileOp->getOpOperands(), options);
    // The result types of a WhileOp are the same as the "after" bbArg types.
    SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          return state.getBufferType(bbArg).cast<Type>();
+          return getBufferType(bbArg, options).cast<Type>();
        }));
    // Construct a new scf.while op with memref instead of tensor values.
@@ -761,7 +762,7 @@ struct WhileOpInterface
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newConditionArgs =
        getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
-                         indicesAfter, state);
+                         indicesAfter, options);
    newConditionOp.getArgsMutable().assign(newConditionArgs);
    // Set up new iter_args and move the loop body block to the new op.
@@ -780,7 +781,7 @@ struct WhileOpInterface
    // TODO: This could be relaxed for better bufferization results.
    SmallVector<Value> newYieldValues =
        getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
-                         indicesBefore, state);
+                         indicesBefore, options);
    newYieldOp.getResultsMutable().assign(newYieldValues);
    // Replace loop results.
@@ -866,7 +867,7 @@ struct YieldOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto yieldOp = cast<scf::YieldOp>(op);
    if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))
@@ -954,7 +955,7 @@ struct ForeachThreadOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard g(b);
    auto foreachThreadOp = cast<ForeachThreadOp>(op);
@@ -966,7 +967,7 @@ struct ForeachThreadOpInterface
      // Insert copies right before the PerformConcurrentlyOp terminator. They
      // should not be inside terminator (which would be the default insertion
      // point).
-      Value buffer = state.getBuffer(b, insertDest->get());
+      Value buffer = getBuffer(b, insertDest->get(), options);
      newResults.push_back(buffer);
    }
@@ -991,8 +992,7 @@ struct ForeachThreadOpInterface
    performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
      Location loc = insertOp.getLoc();
      Type srcType = getMemRefType(
-          insertOp.getSource().getType().cast<RankedTensorType>(),
-          state.getOptions());
+          insertOp.getSource().getType().cast<RankedTensorType>(), options);
      // ParallelInsertSliceOp bufferizes to a copy.
      auto srcMemref = b.create<bufferization::ToMemrefOp>(
          loc, srcType, insertOp.getSource());
@@ -1001,8 +1001,8 @@ struct ForeachThreadOpInterface
          loc, destMemref, insertOp.getMixedOffsets(),
          insertOp.getMixedSizes(), insertOp.getMixedStrides());
      // This memcpy will fold away if everything bufferizes in-place.
-      if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
-                                                 srcMemref, subview)))
+      if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
+                                      subview)))
        return WalkResult::interrupt();
      b.eraseOp(insertOp);
      return WalkResult::advance();
@@ -1022,7 +1022,7 @@ struct PerformConcurrentlyOpInterface
    : public BufferizableOpInterface::ExternalModel<
          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    llvm_unreachable("op does not have any tensor OpOperands / OpResults");
    return failure();
  }
@@ -1110,7 +1110,7 @@ struct ParallelInsertSliceOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    // Will be bufferized as part of ForeachThreadOp.
    return failure();
  }


@@ -59,7 +59,7 @@ struct AssumingOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto assumingOp = cast<shape::AssumingOp>(op);
    // Compute new result types.
@@ -67,7 +67,7 @@ struct AssumingOpInterface
    for (Type type : assumingOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newResultTypes.push_back(getMemRefType(tensorType, options));
      } else {
        newResultTypes.push_back(type);
      }
@@ -152,7 +152,7 @@ struct AssumingYieldOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    // Op is bufferized as part of AssumingOp.
    return failure();
  }


@@ -48,11 +48,11 @@ struct CastOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto castOp = cast<tensor::CastOp>(op);
    // The result buffer still has the old (pre-cast) type.
-    Value resultBuffer = state.getBuffer(rewriter, castOp.source());
+    Value resultBuffer = getBuffer(rewriter, castOp.source(), options);
    auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
    Attribute memorySpace = sourceMemRefType.getMemorySpace();
    TensorType resultTensorType =
@@ -64,8 +64,8 @@ struct CastOpInterface
      layout = rankedMemRefType.getLayout();
    // Compute the new memref type.
-    Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
-                                          layout, memorySpace);
+    Type resultMemRefType =
+        getMemRefType(resultTensorType, options, layout, memorySpace);
    // Replace the op with a memref.cast.
    assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
@@ -105,10 +105,10 @@ struct CollapseShapeOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-    Value buffer = state.getBuffer(rewriter, collapseShapeOp.src());
+    Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options);
    auto bufferType = buffer.getType().cast<MemRefType>();
    if (tensorResultType.getRank() == 0) {
@@ -146,7 +146,7 @@ struct CollapseShapeOpInterface
        bufferType, collapseShapeOp.getReassociationIndices());
    if (!canBeCollapsed) {
      // TODO: Create alloc_tensor ops during TensorCopyInsertion.
-      AnalysisState analysisState(state.getOptions());
+      AnalysisState analysisState(options);
      Value tensorAlloc = allocateTensorForShapedValue(
          rewriter, op->getLoc(), collapseShapeOp.src(),
          analysisState.isTensorYielded(collapseShapeOp.result()));
@@ -185,9 +185,9 @@ struct DimOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto dimOp = cast<tensor::DimOp>(op);
-    auto v = state.getBuffer(rewriter, dimOp.source());
+    auto v = getBuffer(rewriter, dimOp.source(), options);
    replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
    return success();
  }
@@ -220,10 +220,10 @@ struct ExpandShapeOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
    auto tensorResultType = expandShapeOp.getResultType();
-    auto buffer = state.getBuffer(rewriter, expandShapeOp.src());
+    auto buffer = getBuffer(rewriter, expandShapeOp.src(), options);
    // Memref result type is inferred by the builder based on reassociation
    // indices and result shape.
@@ -261,13 +261,13 @@ struct ExtractSliceOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
    Location loc = extractSliceOp.getLoc();
    // Even if this op was decided to bufferize out-of-place, do not insert the
    // buffer copy yet. This is done later in this function.
-    auto srcMemref = state.getBuffer(rewriter, extractSliceOp.source());
+    auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options);
    auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
    auto dstTensorType =
        extractSliceOp.result().getType().cast<RankedTensorType>();
@@ -319,9 +319,9 @@ struct ExtractOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto extractOp = cast<tensor::ExtractOp>(op);
-    Value srcMemref = state.getBuffer(rewriter, extractOp.tensor());
+    Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options);
    replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
                                                 extractOp.indices());
    return success();
@@ -355,7 +355,7 @@ struct FromElementsOpInterface
    : public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
                                                    tensor::FromElementsOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto fromElementsOp = cast<tensor::FromElementsOp>(op);
    // Allocate a buffer for the result.
@@ -363,7 +363,7 @@ struct FromElementsOpInterface
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
    auto shape = tensorType.getShape();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
-    AnalysisState analysisState(state.getOptions());
+    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, fromElementsOp.result(),
        analysisState.isTensorYielded(fromElementsOp.result()),
@@ -410,13 +410,13 @@ struct GenerateOpInterface
    : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
                                                    tensor::GenerateOp> {
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto generateOp = cast<tensor::GenerateOp>(op);
    auto tensorType = generateOp.getType().cast<RankedTensorType>();
    // Allocate memory.
    Location loc = op->getLoc();
    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
-    AnalysisState analysisState(state.getOptions());
+    AnalysisState analysisState(options);
    Value tensorAlloc = allocateTensorForShapedValue(
        rewriter, loc, generateOp.result(),
        analysisState.isTensorYielded(generateOp.result()),
@@ -493,9 +493,9 @@ struct InsertOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto insertOp = cast<tensor::InsertOp>(op);
-    Value destMemref = state.getBuffer(rewriter, insertOp.dest());
+    Value destMemref = getBuffer(rewriter, insertOp.dest(), options);
    rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
                                     destMemref, insertOp.indices());
    replaceOpWithBufferizedValues(rewriter, op, destMemref);
@@ -645,7 +645,7 @@ struct InsertSliceOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    // insert_slice ops arise from tiling and bufferizing them out-of-place is
    // generally a deal breaker. When used with loops, this ends up cloning the
    // whole tensor on every single iteration and is a symptom of a
@@ -653,7 +653,7 @@ struct InsertSliceOpInterface
    // TODO: be very loud about it or even consider failing the pass.
    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
    Location loc = insertSliceOp.getLoc();
-    Value dstMemref = state.getBuffer(rewriter, insertSliceOp.dest());
+    Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options);
    // Expand offsets, sizes and strides to the full rank to handle the
    // rank-reducing case.
@@ -681,9 +681,8 @@ struct InsertSliceOpInterface
    // Copy tensor. If this tensor.insert_slice has a matching
    // tensor.extract_slice, the copy operation will eventually fold away.
-    auto srcMemref = state.getBuffer(rewriter, insertSliceOp.source());
-    if (failed(
-            state.getOptions().createMemCpy(rewriter, loc, srcMemref, subView)))
+    auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options);
+    if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
      return failure();
    replaceOpWithBufferizedValues(rewriter, op, dstMemref);
@@ -711,9 +710,9 @@ struct RankOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto rankOp = cast<tensor::RankOp>(op);
-    auto v = state.getBuffer(rewriter, rankOp.tensor());
+    auto v = getBuffer(rewriter, rankOp.tensor(), options);
    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
                                                 v);
    return success();
@@ -747,12 +746,12 @@ struct ReshapeOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto reshapeOp = cast<tensor::ReshapeOp>(op);
-    Value srcBuffer = state.getBuffer(rewriter, reshapeOp.source());
-    Value shapeBuffer = state.getBuffer(rewriter, reshapeOp.shape());
+    Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options);
+    Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options);
    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
-    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+    auto resultMemRefType = getMemRefType(resultTensorType, options);
    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
        rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
    return success();


@@ -46,11 +46,11 @@ struct TransferReadOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto readOp = cast<vector::TransferReadOp>(op);
    assert(readOp.getShapedType().isa<TensorType>() &&
           "only tensor types expected");
-    Value buffer = state.getBuffer(rewriter, readOp.getSource());
+    Value buffer = getBuffer(rewriter, readOp.getSource(), options);
    replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
        rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
        readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
@@ -91,13 +91,13 @@ struct TransferWriteOpInterface
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
    auto writeOp = cast<vector::TransferWriteOp>(op);
    assert(writeOp.getShapedType().isa<TensorType>() &&
           "only tensor types expected");
    // Create a new transfer_write on buffer that doesn't have a return value.
-    Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource());
+    Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options);
    rewriter.create<vector::TransferWriteOp>(
        writeOp.getLoc(), writeOp.getVector(), resultBuffer,
        writeOp.getIndices(), writeOp.getPermutationMapAttr(),