forked from OSchip/llvm-project
[mlir][bufferize][NFC] Remove BufferizationState
With the recent refactorings, this class is no longer needed. We can use BufferizationOptions in all places where BufferizationState was used.

Differential Revision: https://reviews.llvm.org/D127653
parent c80c57674e
commit b55d55ecd9
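The shape of the refactoring can be illustrated with a minimal self-contained sketch. The types below are hypothetical stand-ins, not the real MLIR classes: before this commit, `BufferizationState` had become a thin wrapper whose only remaining job was to carry a const reference to the options; afterwards, helpers take the options directly.

```cpp
#include <cassert>

// Hypothetical minimal stand-in for the options struct; the real type lives
// in the MLIR bufferization dialect and is far richer.
struct BufferizationOptions {
  unsigned bufferAlignment = 64;
};

// Before: a wrapper class that only forwarded access to the options.
struct BufferizationState {
  explicit BufferizationState(const BufferizationOptions &options)
      : options(options) {}
  const BufferizationOptions &getOptions() const { return options; }

private:
  const BufferizationOptions &options;
};

// After: helpers take the options directly, so the wrapper class (and every
// `state.getOptions()` indirection at call sites) can be deleted.
unsigned alignmentOf(const BufferizationOptions &options) {
  return options.bufferAlignment;
}
```

Both call-site shapes read the same data; the second simply drops one level of indirection.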
@@ -30,40 +30,38 @@ and with aggressive in-place bufferization.
 One-Shot Bufferize is:
 
-* **Monolithic**: A single MLIR pass does the entire
-work, whereas the previous bufferization in MLIR was split across multiple
-passes residing in different dialects. In One-Shot Bufferize,
-`BufferizableOpInterface` implementations are spread across different dialects.
+* **Monolithic**: A single MLIR pass does the entire work, whereas the
+  previous bufferization in MLIR was split across multiple passes residing in
+  different dialects. In One-Shot Bufferize, `BufferizableOpInterface`
+  implementations are spread across different dialects.
 
-* A **whole-function at a time analysis**. In-place bufferization decisions are
-made by analyzing SSA use-def chains on tensors. Op interface implementations
-not only provide the rewrite logic from tensor ops to memref ops, but also
-helper methods for One-Shot Bufferize's analysis to query information about an
-op's bufferization/memory semantics.
+* A **whole-function at a time analysis**. In-place bufferization decisions
+  are made by analyzing SSA use-def chains on tensors. Op interface
+  implementations not only provide the rewrite logic from tensor ops to memref
+  ops, but also helper methods for One-Shot Bufferize's analysis to query
+  information about an op's bufferization/memory semantics.
 
-* **Extensible** via an op interface: All
-ops that implement `BufferizableOpInterface` can be bufferized.
+* **Extensible** via an op interface: All ops that implement
+  `BufferizableOpInterface` can be bufferized.
 
-* **2-Pass**:
-Bufferization is internally broken down into 2 steps: First, analyze the entire
-IR and make bufferization decisions. Then, bufferize (rewrite) the IR. The
-analysis has access to exact SSA use-def information. It incrementally builds
-alias and equivalence sets and does not rely on a posteriori-alias analysis from
-preallocated memory.
+* **2-Pass**: Bufferization is internally broken down into 2 steps: First,
+  analyze the entire IR and make bufferization decisions. Then, bufferize
+  (rewrite) the IR. The analysis has access to exact SSA use-def information.
+  It incrementally builds alias and equivalence sets and does not rely on a
+  posteriori-alias analysis from preallocated memory.
 
-* **Greedy**: Operations are analyzed one-by-one and it is
-decided on the spot whether a tensor OpOperand must be copied or not. Heuristics
-determine the order of analysis.
+* **Greedy**: Operations are analyzed one-by-one and it is decided on the spot
+  whether a tensor OpOperand must be copied or not. Heuristics determine the
+  order of analysis.
 
-* **Modular**: The current One-Shot Analysis
-can be replaced with a different analysis. The result of the analysis are
-queried by the bufferization via `BufferizationState`, in particular
-`BufferizationState::isInPlace`. Any derived class of `BufferizationState` that
-implements a small number virtual functions can serve as a custom analysis. It
-is even possible to run One-Shot Bufferize without any analysis
-(`AlwaysCopyBufferizationState`), in which case One-Shot Bufferize behaves
-exactly like the old dialect conversion-based bufferization (i.e., copy every
-buffer before writing to it).
+* **Modular**: The current One-Shot Analysis can be replaced with a different
+  analysis. The result of the analysis are queried by the bufferization via
+  `AnalysisState`, in particular `AnalysisState::isInPlace`. Any derived class
+  of `AnalysisState` that implements a small number virtual functions can
+  serve as a custom analysis. It is even possible to run One-Shot Bufferize
+  without any analysis (`AlwaysCopyAnalysisState`), in which case One-Shot
+  Bufferize behaves exactly like the old dialect conversion-based
+  bufferization (i.e., copy every buffer before writing to it).
 
 To reduce complexity, One-Shot Bufferize should be
 [run after other transformations](https://llvm.discourse.group/t/rfc-linalg-on-tensors-update-and-comprehensive-bufferization-rfc/3373),
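The 2-pass, greedy structure described in the documentation above can be sketched with toy types. Everything here is a hypothetical illustration, not the actual MLIR analysis: pass 1 walks the ops and records an in-place decision per op; pass 2 rewrites using only those precomputed decisions.

```cpp
#include <map>
#include <string>
#include <vector>

// Toy stand-in: each "op" writes to one tensor operand; whether it may do so
// in place is decided by a separate analysis pass before any rewriting.
struct ToyOp {
  std::string name;
  bool lastUseOfOperand; // toy greedy-heuristic input
};

// Pass 1: analyze every op and record an in-place/out-of-place decision.
std::map<std::string, bool> analyze(const std::vector<ToyOp> &ops) {
  std::map<std::string, bool> inPlace;
  for (const ToyOp &op : ops)
    // Greedy, on-the-spot decision: write in place only if this is the last
    // use of the operand, so no later read can observe the overwrite.
    inPlace[op.name] = op.lastUseOfOperand;
  return inPlace;
}

// Pass 2: rewrite, consulting only the precomputed decisions.
std::vector<std::string> rewrite(const std::vector<ToyOp> &ops,
                                 const std::map<std::string, bool> &inPlace) {
  std::vector<std::string> out;
  for (const ToyOp &op : ops)
    out.push_back(op.name + (inPlace.at(op.name) ? ": in-place"
                                                 : ": copy-then-write"));
  return out;
}
```

The split mirrors the real design point: the rewrite step never re-derives aliasing facts, it only consumes decisions made while exact SSA use-def information was still available.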
@@ -236,7 +236,7 @@ struct BufferizationOptions {
   ///
   /// Note: Deactivating this flag can lead to incorrect bufferization results
   /// when used incorrectly. This flag is useful with
-  /// `AlwaysCopyBufferizationState` which bufferizes all writing tensor
+  /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor
   /// OpOperands out-of-place.
   bool enforceAliasingInvariants = true;

@@ -464,33 +464,6 @@ private:
   const BufferizationOptions &options;
 };
 
-/// BufferizationState provides helper functions for performing bufferization
-/// rewrites and handling memref buffers.
-struct BufferizationState {
-  BufferizationState(const BufferizationOptions &options) : options(options) {}
-
-  /// Lookup the buffer for the given value. If the value was not bufferized
-  /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
-  /// from which the memref operand is returned.
-  Value getBuffer(RewriterBase &rewriter, Value value);
-
-  /// Return the buffer type for a given Value (tensor) after bufferization.
-  ///
-  /// Note: Op implementations should preferrably call `getBuffer()->getType()`.
-  /// This function should only be used if `getBuffer` cannot be used.
-  BaseMemRefType getBufferType(Value value) const;
-
-  /// Return a reference to the BufferizationOptions.
-  const BufferizationOptions &getOptions() const { return options; }
-
-protected:
-  // BufferizationState should be passed as a reference.
-  BufferizationState(const BufferizationState &) = delete;
-
-private:
-  const BufferizationOptions &options;
-};
-
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
 /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
 /// undefined contents is allocated.

@@ -498,6 +471,18 @@ Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
                                    Value shapedValue, bool escape,
                                    bool copy = true);
 
+/// Lookup the buffer for the given value. If the value was not bufferized
+/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
+/// from which the memref operand is returned.
+Value getBuffer(RewriterBase &rewriter, Value value,
+                const BufferizationOptions &options);
+
+/// Return the buffer type for a given Value (tensor) after bufferization.
+///
+/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
+/// This function should only be used if `getBuffer` cannot be used.
+BaseMemRefType getBufferType(Value value, const BufferizationOptions &options);
+
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -221,7 +221,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
     InterfaceMethod<
       /*desc=*/[{
         Bufferize this op, i.e., rewrite it into a memref-based equivalent.
-        Buffers of tensor SSA values can be retrieved via `state.getBuffer`.
+        Buffers of tensor SSA values can be retrieved via `getBuffer`.
         Uses of tensor results of the existing tensor op can be replaced with
         `replaceOpWithBufferizedValues` or `replaceOpWithNewBufferizedOp`.
         These two functions automatically handle the tensor-to-memref type

@@ -233,12 +233,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         a) A buffer that aliases one of buffers in getAliasingOpOperand(r).
         b) Or: A newly allocated buffer.
 
-        Regions of an op should be inlined into the new op instead of cloning
-        them. This is not only more efficient, but also necessary so that no
-        analysis results are lost. (Bufferization decisions are tracked via
-        OpOperand pointers and cloned ops have new OpOperands.) If regions are
-        cloned instead of inlined, additional buffer copies may be inserted.
-
        This method will never be called on ops that do not have at least one
        tensor operand/result.

@@ -252,7 +246,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
      /*retType=*/"LogicalResult",
      /*methodName=*/"bufferize",
      /*args=*/(ins "RewriterBase &":$rewriter,
-                    "BufferizationState &":$state),
+                    "const BufferizationOptions &":$options),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        llvm_unreachable("bufferize not implemented");
@@ -71,7 +71,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
   let results = (outs AnyTensor:$result);
 
   let extraClassDeclaration = [{
-    LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);
+    LogicalResult bufferize(RewriterBase &rewriter,
+                            const BufferizationOptions &options);
 
     bool isMemoryWrite(OpResult opResult, const AnalysisState &state);

@@ -242,7 +243,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
    // results as not writable enforces a buffer copy and has the same effect.
 
    LogicalResult bufferize(RewriterBase &rewriter,
-                            BufferizationState &state) const {
+                            const BufferizationOptions &options) const {
      // to_tensor cannot be bufferized. However, other ops that are using
      // to_tensor's result will eventually be bufferized. At that point, they
      // will start using to_tensor's memref operand. Once all users of

@@ -334,7 +335,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
    }
 
    LogicalResult bufferize(RewriterBase &rewriter,
-                            BufferizationState &state);
+                            const BufferizationOptions &options);
  }];
 
  let assemblyFormat = "$tensor attr-dict `:` type($memref)";
@@ -25,7 +25,6 @@ namespace mlir {
 namespace bufferization {
 
 class AnalysisState;
-struct BufferizationState;
 struct BufferizationOptions;
 class OpFilter;
@@ -15,7 +15,6 @@ struct LogicalResult;
 class ModuleOp;
 
 namespace bufferization {
-struct BufferizationState;
 class OneShotAnalysisState;
 struct OneShotBufferizationOptions;
@@ -23,7 +23,7 @@ struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
 
     // Only ranked tensors are supported.

@@ -38,7 +38,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, state.getOptions().bufferAlignment);
+        getGlobalFor(constantOp, options.bufferAlignment);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = globalOp.getValue();

@@ -80,11 +80,11 @@ struct IndexCastOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = castOp.getType().cast<TensorType>();
 
-    Value source = state.getBuffer(rewriter, castOp.getIn());
+    Value source = getBuffer(rewriter, castOp.getIn(), options);
     auto sourceType = source.getType().cast<BaseMemRefType>();
 
     // Result type should have same layout and address space as the source type.

@@ -132,7 +132,7 @@ struct SelectOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();

@@ -140,8 +140,8 @@ struct SelectOpInterface
     // instead of its OpOperands. In the worst case, 2 copies are inserted at
     // the moment (one for each tensor). When copying the op result, only one
     // copy would be needed.
-    Value trueBuffer = state.getBuffer(rewriter, selectOp.getTrueValue());
-    Value falseBuffer = state.getBuffer(rewriter, selectOp.getFalseValue());
+    Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options);
+    Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options);
 
     // The "true" and the "false" operands must have the same type. If the
     // buffers have different types, they differ only in their layout map. Cast
@@ -477,7 +477,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
 #endif
 }
 
-Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
+Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
+                               const BufferizationOptions &options) {
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");

@@ -488,21 +489,22 @@ Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  Type memrefType = getMemRefType(tensorType, getOptions());
+  Type memrefType = getMemRefType(tensorType, options);
   ensureToMemrefOpIsValid(value, memrefType);
   return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
                                                     value);
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-BaseMemRefType BufferizationState::getBufferType(Value value) const {
+BaseMemRefType
+bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
 
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
     return toTensorOp.memref().getType().cast<BaseMemRefType>();
 
-  return getMemRefType(tensorType, getOptions());
+  return getMemRefType(tensorType, options);
 }
 
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
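The contract of the free-standing `getBuffer` helper above (reuse the memref behind a `to_tensor`, otherwise wrap the value in a `to_memref`) can be sketched with toy string-based stand-ins. These types are purely illustrative, not the real MLIR IR classes:

```cpp
#include <string>

// Toy model of a tensor SSA value: if it was produced by a to_tensor op, we
// remember the name of the memref that op wrapped; otherwise the field is
// empty.
struct ToyValue {
  std::string name;
  std::string toTensorSourceMemref; // empty if not defined by to_tensor
};

// Toy analogue of bufferization::getBuffer: either reuse the existing buffer
// or "insert" a to_memref wrapper around the tensor value.
std::string getToyBuffer(const ToyValue &value) {
  if (!value.toTensorSourceMemref.empty())
    return value.toTensorSourceMemref; // value came from to_tensor: reuse
  return "to_memref(" + value.name + ")"; // otherwise wrap it
}
```

The real helper additionally validates the memref type against the options-derived layout, which the toy version omits.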
@@ -150,7 +150,7 @@ void mlir::bufferization::populateDynamicDimSizes(
 //===----------------------------------------------------------------------===//
 
 LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   OpBuilder::InsertionGuard g(rewriter);
   Location loc = getLoc();

@@ -163,7 +163,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
   // Create buffer allocation.
   Value copyBuffer;
   if (copy())
-    copyBuffer = state.getBuffer(rewriter, copy());
+    copyBuffer = getBuffer(rewriter, copy(), options);
   auto allocType =
       MemRefType::get(getType().getShape(), getType().getElementType());
   SmallVector<Value> dynamicDims = dynamicSizes();

@@ -172,25 +172,24 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
     populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
   }
   FailureOr<Value> alloc =
-      state.getOptions().createAlloc(rewriter, loc, allocType, dynamicDims);
+      options.createAlloc(rewriter, loc, allocType, dynamicDims);
   if (failed(alloc))
     return failure();
 
   // Create memory copy (if any).
   if (copy()) {
-    if (failed(
-            state.getOptions().createMemCpy(rewriter, loc, copyBuffer, *alloc)))
+    if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
       return failure();
   }
 
   // Should the buffer be deallocated?
-  AnalysisState analysisState(state.getOptions());
+  AnalysisState analysisState(options);
   bool dealloc;
   if (escape().hasValue()) {
     dealloc = !*escape();
   } else {
     // No "escape" annotation found.
-    if (state.getOptions().createDeallocs) {
+    if (options.createDeallocs) {
       // Perform an ad-hoc analysis.
       dealloc = !analysisState.isTensorYielded(getResult());
     } else {

@@ -206,7 +205,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
     return success();
 
   rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
-  if (failed(state.getOptions().createDealloc(rewriter, loc, *alloc)))
+  if (failed(options.createDealloc(rewriter, loc, *alloc)))
     return failure();
   return success();
 }

@@ -627,7 +626,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
   (void)foldToMemrefToTensorPair(rewriter, *this);
   // Note: The return value of `bufferize` indicates whether there was an error
@@ -401,7 +401,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
   DenseSet<Operation *> erasedOps;
 
   // Bufferize all ops.
-  BufferizationState bufferizationState(options);
   BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
                                  worklist, options, opFilter);
   for (unsigned i = 0; i < worklist.size(); ++i) {

@@ -420,7 +419,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
       continue;
     // Bufferize the op.
     rewriter.setInsertionPoint(op);
-    if (failed(bufferizableOp.bufferize(rewriter, bufferizationState)))
+    if (failed(bufferizableOp.bufferize(rewriter, options)))
       return op->emitError("failed to bufferize op");
   }

@@ -433,7 +432,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
 
   /// Check the result of bufferization. Return an error if an op was not
   /// bufferized, unless partial bufferization is allowed.
-  if (bufferizationState.getOptions().allowUnknownOps)
+  if (options.allowUnknownOps)
     return success();
 
   for (Operation *op : worklist) {
@@ -258,7 +258,7 @@ struct CallOpInterface
   /// All function arguments are writable. It is the responsibility of the
   /// CallOp to insert buffer copies where necessary.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     func::CallOp callOp = cast<func::CallOp>(op);
     unsigned numResults = callOp.getNumResults();
     unsigned numOperands = callOp->getNumOperands();

@@ -307,7 +307,7 @@ struct CallOpInterface
       // Retrieve buffers for tensor operands.
       Value buffer = newOperands[idx];
       if (!buffer)
-        buffer = state.getBuffer(rewriter, opOperand.get());
+        buffer = getBuffer(rewriter, opOperand.get(), options);
 
       // Caller / callee type mismatch is handled with a CastOp.
       auto memRefType = funcType.getInput(idx);

@@ -364,7 +364,7 @@ struct ReturnOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
 #ifndef NDEBUG
     auto returnOp = cast<func::ReturnOp>(op);
     assert(isa<FuncOp>(returnOp->getParentOp()) &&

@@ -386,11 +386,9 @@ struct FuncOpInterface
   /// All function bbArgs are writable unless they are explicitly marked as
   /// read-only. Callers must insert copies when needed.
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto funcOp = cast<FuncOp>(op);
     FunctionType funcType = funcOp.getFunctionType();
-    const OneShotBufferizationOptions &options =
-        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
 
     // Construct the bufferized function type.
     SmallVector<Type> argTypes;
@@ -429,7 +429,6 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
   IRRewriter rewriter(moduleOp.getContext());
-  BufferizationState bufferizationState(options);
 
   // A list of functions in the order in which they are analyzed + bufferized.
   SmallVector<func::FuncOp> orderedFuncOps;
@@ -20,11 +20,9 @@ using namespace mlir::bufferization;
 
 namespace {
 
-// TODO: Ops in the linalg dialect can directly implement this interface.
-
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
-                                       BufferizationState &state) {
+                                       const BufferizationOptions &options) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(op);

@@ -46,14 +44,14 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(state.getBuffer(rewriter, opOperand->get()));
+    newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options));
   }
 
   // New output operands for the cloned op.
   SmallVector<Value> newOutputBuffers;
   for (OpResult opResult : op->getOpResults()) {
     OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
-    Value resultBuffer = state.getBuffer(rewriter, opOperand->get());
+    Value resultBuffer = getBuffer(rewriter, opOperand->get(), options);
     newOutputBuffers.push_back(resultBuffer);
   }

@@ -123,8 +121,8 @@ struct LinalgOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
-    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
+                          const BufferizationOptions &options) const {
+    return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options);
   }
 };
@@ -73,7 +73,7 @@ struct ExecuteRegionOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
 
     // Compute new result types.

@@ -81,7 +81,7 @@ struct ExecuteRegionOpInterface
     for (Type type : executeRegionOp->getResultTypes()) {
       if (auto tensorType = type.dyn_cast<TensorType>()) {
         // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newResultTypes.push_back(getMemRefType(tensorType, options));
       } else {
         newResultTypes.push_back(type);
       }

@@ -183,7 +183,7 @@ struct IfOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto ifOp = cast<scf::IfOp>(op);
 
     // Compute new types of the bufferized scf.if op.

@@ -191,7 +191,7 @@ struct IfOpInterface
     for (Type returnType : ifOp->getResultTypes()) {
       if (auto tensorType = returnType.dyn_cast<TensorType>()) {
         // TODO: Infer the result type instead of computing it.
-        newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+        newTypes.push_back(getMemRefType(tensorType, options));
       } else {
         newTypes.push_back(returnType);
       }

@@ -309,11 +309,11 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static SmallVector<Value> getBuffers(RewriterBase &rewriter,
                                      MutableArrayRef<OpOperand> operands,
-                                     BufferizationState &state) {
+                                     const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
     if (opOperand.get().getType().isa<TensorType>()) {
-      Value resultBuffer = state.getBuffer(rewriter, opOperand.get());
+      Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
       result.push_back(resultBuffer);
     } else {
       result.push_back(opOperand.get());

@@ -325,10 +325,11 @@ static SmallVector<Value> getBuffers(RewriterBase &rewriter,
 /// Helper function for loop bufferization. Compute the buffer that should be
 /// yielded from a loop block (loop body or loop condition).
 static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
-                              BaseMemRefType type, BufferizationState &state) {
+                              BaseMemRefType type,
+                              const BufferizationOptions &options) {
   assert(tensor.getType().isa<TensorType>() && "expected tensor");
   ensureToMemrefOpIsValid(tensor, type);
-  Value yieldedVal = state.getBuffer(rewriter, tensor);
+  Value yieldedVal = getBuffer(rewriter, tensor, options);
   return castBuffer(rewriter, yieldedVal, type);
 }

@@ -352,12 +353,12 @@ convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
 SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
                                     TypeRange bufferizedTypes,
                                     const DenseSet<int64_t> &tensorIndices,
-                                    BufferizationState &state) {
+                                    const BufferizationOptions &options) {
   return convertTensorValues(
       values, tensorIndices, [&](Value val, int64_t index) {
        return getYieldedBuffer(rewriter, val,
                                bufferizedTypes[index].cast<BaseMemRefType>(),
-                                state);
+                                options);
      });
 }

@@ -472,7 +473,7 @@ struct ForOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
     Block *oldLoopBody = &forOp.getLoopBody().front();

@@ -482,7 +483,7 @@ struct ForOpInterface
 
     // The new memref init_args of the loop.
     SmallVector<Value> initArgs =
-        getBuffers(rewriter, forOp.getIterOpOperands(), state);
+        getBuffers(rewriter, forOp.getIterOpOperands(), options);
 
     // Construct a new scf.for op with memref instead of tensor values.
     auto newForOp = rewriter.create<scf::ForOp>(

@@ -511,7 +512,7 @@ struct ForOpInterface
     auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
     rewriter.setInsertionPoint(yieldOp);
     SmallVector<Value> yieldValues = getYieldedValues(
-        rewriter, yieldOp.getResults(), initArgsTypes, indices, state);
+        rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
     yieldOp.getResultsMutable().assign(yieldValues);
 
     // Replace loop results.

@@ -704,7 +705,7 @@ struct WhileOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto whileOp = cast<scf::WhileOp>(op);
 
     assert(whileOp.getBefore().getBlocks().size() == 1 &&

@@ -722,12 +723,12 @@ struct WhileOpInterface
 
     // The new memref init_args of the loop.
     SmallVector<Value> initArgs =
-        getBuffers(rewriter, whileOp->getOpOperands(), state);
+        getBuffers(rewriter, whileOp->getOpOperands(), options);
 
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
        llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          return state.getBufferType(bbArg).cast<Type>();
+          return getBufferType(bbArg, options).cast<Type>();
        }));
 
     // Construct a new scf.while op with memref instead of tensor values.

@@ -761,7 +762,7 @@ struct WhileOpInterface
     // TODO: This could be relaxed for better bufferization results.
     SmallVector<Value> newConditionArgs =
         getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
-                         indicesAfter, state);
+                         indicesAfter, options);
     newConditionOp.getArgsMutable().assign(newConditionArgs);
 
     // Set up new iter_args and move the loop body block to the new op.

@@ -780,7 +781,7 @@ struct WhileOpInterface
     // TODO: This could be relaxed for better bufferization results.
     SmallVector<Value> newYieldValues =
         getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
-                         indicesBefore, state);
+                         indicesBefore, options);
     newYieldOp.getResultsMutable().assign(newYieldValues);
 
     // Replace loop results.

@@ -866,7 +867,7 @@ struct YieldOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     auto yieldOp = cast<scf::YieldOp>(op);
     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
            yieldOp->getParentOp()))

@@ -954,7 +955,7 @@ struct ForeachThreadOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     OpBuilder::InsertionGuard g(b);
     auto foreachThreadOp = cast<ForeachThreadOp>(op);

@@ -966,7 +967,7 @@ struct ForeachThreadOpInterface
      // Insert copies right before the PerformConcurrentlyOp terminator. They
      // should not be inside terminator (which would be the default insertion
      // point).
-      Value buffer = state.getBuffer(b, insertDest->get());
+      Value buffer = getBuffer(b, insertDest->get(), options);
      newResults.push_back(buffer);
    }

@@ -991,8 +992,7 @@ struct ForeachThreadOpInterface
     performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
       Location loc = insertOp.getLoc();
       Type srcType = getMemRefType(
-          insertOp.getSource().getType().cast<RankedTensorType>(),
-          state.getOptions());
+          insertOp.getSource().getType().cast<RankedTensorType>(), options);
       // ParallelInsertSliceOp bufferizes to a copy.
       auto srcMemref = b.create<bufferization::ToMemrefOp>(
           loc, srcType, insertOp.getSource());

@@ -1001,8 +1001,8 @@ struct ForeachThreadOpInterface
           loc, destMemref, insertOp.getMixedOffsets(),
           insertOp.getMixedSizes(), insertOp.getMixedStrides());
       // This memcpy will fold away if everything bufferizes in-place.
-      if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
-                                                 srcMemref, subview)))
+      if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
+                                      subview)))
         return WalkResult::interrupt();
       b.eraseOp(insertOp);
       return WalkResult::advance();

@@ -1022,7 +1022,7 @@ struct PerformConcurrentlyOpInterface
     : public BufferizableOpInterface::ExternalModel<
           PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     llvm_unreachable("op does not have any tensor OpOperands / OpResults");
     return failure();
   }

@@ -1110,7 +1110,7 @@ struct ParallelInsertSliceOpInterface
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &b,
-                          BufferizationState &state) const {
+                          const BufferizationOptions &options) const {
     // Will be bufferized as part of ForeachThreadOp.
     return failure();
   }
@ -59,7 +59,7 @@ struct AssumingOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto assumingOp = cast<shape::AssumingOp>(op);
|
||||
|
||||
// Compute new result types.
|
||||
|
@ -67,7 +67,7 @@ struct AssumingOpInterface
|
|||
for (Type type : assumingOp->getResultTypes()) {
|
||||
if (auto tensorType = type.dyn_cast<TensorType>()) {
|
||||
// TODO: Infer the result type instead of computing it.
|
||||
newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
|
||||
newResultTypes.push_back(getMemRefType(tensorType, options));
|
||||
} else {
|
||||
newResultTypes.push_back(type);
|
||||
}
|
||||
|
@ -152,7 +152,7 @@ struct AssumingYieldOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
// Op is bufferized as part of AssumingOp.
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -48,11 +48,11 @@ struct CastOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto castOp = cast<tensor::CastOp>(op);
|
||||
|
||||
// The result buffer still has the old (pre-cast) type.
|
||||
Value resultBuffer = state.getBuffer(rewriter, castOp.source());
|
||||
Value resultBuffer = getBuffer(rewriter, castOp.source(), options);
|
||||
auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
|
||||
Attribute memorySpace = sourceMemRefType.getMemorySpace();
|
||||
TensorType resultTensorType =
|
||||
|
@ -64,8 +64,8 @@ struct CastOpInterface
|
|||
layout = rankedMemRefType.getLayout();
|
||||
|
||||
// Compute the new memref type.
|
||||
Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
|
||||
layout, memorySpace);
|
||||
Type resultMemRefType =
|
||||
getMemRefType(resultTensorType, options, layout, memorySpace);
|
||||
|
||||
// Replace the op with a memref.cast.
|
||||
assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
|
||||
|
@ -105,10 +105,10 @@ struct CollapseShapeOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
|
||||
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
|
||||
Value buffer = state.getBuffer(rewriter, collapseShapeOp.src());
|
||||
Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options);
|
||||
auto bufferType = buffer.getType().cast<MemRefType>();
|
||||
|
||||
if (tensorResultType.getRank() == 0) {
|
||||
|
@ -146,7 +146,7 @@ struct CollapseShapeOpInterface
|
|||
bufferType, collapseShapeOp.getReassociationIndices());
|
||||
if (!canBeCollapsed) {
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
AnalysisState analysisState(state.getOptions());
|
||||
AnalysisState analysisState(options);
|
||||
Value tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), collapseShapeOp.src(),
|
||||
analysisState.isTensorYielded(collapseShapeOp.result()));
|
||||
|
@ -185,9 +185,9 @@ struct DimOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto dimOp = cast<tensor::DimOp>(op);
|
||||
auto v = state.getBuffer(rewriter, dimOp.source());
|
||||
auto v = getBuffer(rewriter, dimOp.source(), options);
|
||||
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
|
||||
return success();
|
||||
}
|
||||
|
@ -220,10 +220,10 @@ struct ExpandShapeOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
|
||||
auto tensorResultType = expandShapeOp.getResultType();
|
||||
auto buffer = state.getBuffer(rewriter, expandShapeOp.src());
|
||||
auto buffer = getBuffer(rewriter, expandShapeOp.src(), options);
|
||||
|
||||
// Memref result type is inferred by the builder based on reassociation
|
||||
// indices and result shape.
|
||||
|
@ -261,13 +261,13 @@ struct ExtractSliceOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
||||
Location loc = extractSliceOp.getLoc();
|
||||
|
||||
// Even if this op was decided to bufferize out-of-place, do not insert the
|
||||
// buffer copy yet. This is done later in this function.
|
||||
auto srcMemref = state.getBuffer(rewriter, extractSliceOp.source());
|
||||
auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options);
|
||||
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
|
||||
auto dstTensorType =
|
||||
extractSliceOp.result().getType().cast<RankedTensorType>();
|
||||
|
@ -319,9 +319,9 @@ struct ExtractOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto extractOp = cast<tensor::ExtractOp>(op);
|
||||
Value srcMemref = state.getBuffer(rewriter, extractOp.tensor());
|
||||
Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options);
|
||||
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
|
||||
extractOp.indices());
|
||||
return success();
|
||||
|
@ -355,7 +355,7 @@ struct FromElementsOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
|
||||
tensor::FromElementsOp> {
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
|
||||
|
||||
// Allocate a buffer for the result.
|
||||
|
@ -363,7 +363,7 @@ struct FromElementsOpInterface
|
|||
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
|
||||
auto shape = tensorType.getShape();
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
AnalysisState analysisState(state.getOptions());
|
||||
AnalysisState analysisState(options);
|
||||
Value tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, fromElementsOp.result(),
|
||||
analysisState.isTensorYielded(fromElementsOp.result()),
|
||||
|
@ -410,13 +410,13 @@ struct GenerateOpInterface
|
|||
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
|
||||
tensor::GenerateOp> {
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto generateOp = cast<tensor::GenerateOp>(op);
|
||||
auto tensorType = generateOp.getType().cast<RankedTensorType>();
|
||||
// Allocate memory.
|
||||
Location loc = op->getLoc();
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
AnalysisState analysisState(state.getOptions());
|
||||
AnalysisState analysisState(options);
|
||||
Value tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, generateOp.result(),
|
||||
analysisState.isTensorYielded(generateOp.result()),
|
||||
|
@ -493,9 +493,9 @@ struct InsertOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto insertOp = cast<tensor::InsertOp>(op);
|
||||
Value destMemref = state.getBuffer(rewriter, insertOp.dest());
|
||||
Value destMemref = getBuffer(rewriter, insertOp.dest(), options);
|
||||
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
|
||||
destMemref, insertOp.indices());
|
||||
replaceOpWithBufferizedValues(rewriter, op, destMemref);
|
||||
|
@ -645,7 +645,7 @@ struct InsertSliceOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
// insert_slice ops arise from tiling and bufferizing them out-of-place is
|
||||
// generally a deal breaker. When used with loops, this ends up cloning the
|
||||
// whole tensor on every single iteration and is a symptom of a
|
||||
|
@ -653,7 +653,7 @@ struct InsertSliceOpInterface
|
|||
// TODO: be very loud about it or even consider failing the pass.
|
||||
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
|
||||
Location loc = insertSliceOp.getLoc();
|
||||
Value dstMemref = state.getBuffer(rewriter, insertSliceOp.dest());
|
||||
Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options);
|
||||
|
||||
// Expand offsets, sizes and strides to the full rank to handle the
|
||||
// rank-reducing case.
|
||||
|
@ -681,9 +681,8 @@ struct InsertSliceOpInterface
|
|||
|
||||
// Copy tensor. If this tensor.insert_slice has a matching
|
||||
// tensor.extract_slice, the copy operation will eventually fold away.
|
||||
auto srcMemref = state.getBuffer(rewriter, insertSliceOp.source());
|
||||
if (failed(
|
||||
state.getOptions().createMemCpy(rewriter, loc, srcMemref, subView)))
|
||||
auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options);
|
||||
if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
|
||||
return failure();
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, dstMemref);
|
||||
|
@ -711,9 +710,9 @@ struct RankOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto rankOp = cast<tensor::RankOp>(op);
|
||||
auto v = state.getBuffer(rewriter, rankOp.tensor());
|
||||
auto v = getBuffer(rewriter, rankOp.tensor(), options);
|
||||
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
|
||||
v);
|
||||
return success();
|
||||
|
@ -747,12 +746,12 @@ struct ReshapeOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto reshapeOp = cast<tensor::ReshapeOp>(op);
|
||||
Value srcBuffer = state.getBuffer(rewriter, reshapeOp.source());
|
||||
Value shapeBuffer = state.getBuffer(rewriter, reshapeOp.shape());
|
||||
Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options);
|
||||
Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options);
|
||||
auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
|
||||
auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
|
||||
auto resultMemRefType = getMemRefType(resultTensorType, options);
|
||||
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
|
||||
rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
|
||||
return success();
|
||||
|
|
|
@ -46,11 +46,11 @@ struct TransferReadOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto readOp = cast<vector::TransferReadOp>(op);
|
||||
assert(readOp.getShapedType().isa<TensorType>() &&
|
||||
"only tensor types expected");
|
||||
Value buffer = state.getBuffer(rewriter, readOp.getSource());
|
||||
Value buffer = getBuffer(rewriter, readOp.getSource(), options);
|
||||
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
|
||||
rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
|
||||
readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
|
||||
|
@ -91,13 +91,13 @@ struct TransferWriteOpInterface
|
|||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
BufferizationState &state) const {
|
||||
const BufferizationOptions &options) const {
|
||||
auto writeOp = cast<vector::TransferWriteOp>(op);
|
||||
assert(writeOp.getShapedType().isa<TensorType>() &&
|
||||
"only tensor types expected");
|
||||
|
||||
// Create a new transfer_write on buffer that doesn't have a return value.
|
||||
Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource());
|
||||
Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options);
|
||||
rewriter.create<vector::TransferWriteOp>(
|
||||
writeOp.getLoc(), writeOp.getVector(), resultBuffer,
|
||||
writeOp.getIndices(), writeOp.getPermutationMapAttr(),