diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h index a9a344d75c1c..7ef016f7c5dd 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -11,6 +11,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" @@ -240,9 +241,8 @@ struct AllocationCallbacks { /// BufferizationState keeps track of bufferization state and provides access to /// the results of the analysis. struct BufferizationState { - BufferizationState(BufferizationAliasInfo &aliasInfo, - AllocationCallbacks &allocationFns) - : aliasInfo(aliasInfo), allocationFns(allocationFns) {} + BufferizationState(ModuleOp moduleOp, AllocationCallbacks &allocationFns) + : aliasInfo(moduleOp), allocationFns(allocationFns) {} // BufferizationState should be passed as a reference. BufferizationState(const BufferizationState &) = delete; @@ -270,8 +270,11 @@ struct BufferizationState { /// Mark `op` as obsolete, so that it is deleted after bufferization. void markOpObsolete(Operation *op); + /// Erase all ops that were marked obsolete. + void eraseObsoleteOps(); + /// `aliasInfo` keeps track of aliasing and equivalent values. - BufferizationAliasInfo &aliasInfo; + BufferizationAliasInfo aliasInfo; /// `allocationFns` contains helper functions for creating alloc ops, dealloc /// ops and memcpy ops. @@ -283,6 +286,10 @@ struct BufferizationState { /// Obsolete ops that should be deleted after bufferization. SmallVector obsoleteOps; + + /// A map for looking up bufferized function types. + // TODO: Entangle function calls and FuncOps from the remaining bufferization. + DenseMap bufferizedFunctionTypes; }; /// Return the result buffer (memref) for a given OpResult (tensor). Allocate diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h index 653ec7b36eb8..6db6ba6db3c5 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -25,11 +25,7 @@ static constexpr int64_t kBufferAlignments = 128; std::unique_ptr defaultAllocationCallbacks(); /// Bufferize one particular op. -/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be -/// non-null if `op` is a CallOpInterface (resp. GlobalCreator). -LogicalResult -bufferizeOp(Operation *op, BufferizationState &state, - DenseMap *bufferizedFunctionTypes = nullptr); +LogicalResult bufferizeOp(Operation *op, BufferizationState &state); /// Register external models implemented for the `BufferizableOpInterface`. void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp index 150ffd7e45f3..630415bf469d 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -470,3 +470,10 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::markOpObsolete( Operation *op) { obsoleteOps.push_back(op); } + +void mlir::linalg::comprehensive_bufferize::BufferizationState:: + eraseObsoleteOps() { + for (Operation *op : obsoleteOps) + op->erase(); + obsoleteOps.clear(); +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp index 22f5493b6b80..ea9309e01d87 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -783,144 +783,6 @@ static Value createNewAllocDeallocPairForShapedValue( // Bufferization as simple BlockAndValueMapping rewrites. //===----------------------------------------------------------------------===// -/// In a first approximation, all the function arguments of a FuncOp are marked -/// inplaceable. For now, it is the responsibility of the `callOp` bufferization -/// to allow FuncOp that are inplaceable to write inPlace. -static LogicalResult -bufferize(OpBuilder &b, CallOpInterface callOp, BufferizationState &state, - DenseMap &bufferizedFunctionTypes) { - FuncOp funcOp = getCalledFunction(callOp); - assert(isa(callOp.getOperation()) && funcOp && - "expected Callop to a FuncOp"); - - // If nothing to do then we are done. - if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) && - !llvm::any_of(funcOp.getType().getResults(), isaTensor)) - return success(); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(callOp); - - // 1. Filter return types: - // - if the callee is bodiless / external, we cannot inspect it and we - // cannot assume anything. We can just assert that it does not return a - // tensor as this would have to bufferize to "return a memref", whose - // semantics is ill-defined. - // - if the callee has a body, we perform inter-procedural equivalence - // analysis. When successful, a result folds onto an operand. When - // unsuccessful, additional work is needed to either: - // * hoist a result into an inplaceable operand or - // * devise a better representation to truly return a buffer. - SmallVector resultTypes; - SmallVector hoistedArguments; - if (funcOp.body().empty()) { - if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) - return callOp->emitError() - << "cannot bufferize bodiless function that returns a tensor"; - } else { - ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); - - // For each FuncOp result, keep track of which inplace argument it reuses. - for (OpOperand &returnOperand : returnOp->getOpOperands()) { - Type returnType = returnOperand.get().getType(); - if (!isaTensor(returnType)) { - resultTypes.push_back(returnType); - continue; - } - - // If return operand is equivalent to some bbArg, no need to return it. - Value returnVal = returnOperand.get(); - if (BlockArgument bbArg = - getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) { - Value oldRes = callOp->getResult(returnOperand.getOperandNumber()); - int64_t idx = bbArg.getArgNumber(); - Value buffer = state.lookupBuffer(callOp->getOperand(idx)); - // Add CallOp operand/result equivalence: this is interprocedural info. - state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer); - state.mapBuffer(oldRes, buffer); - // Add a TensorLoadOp to kill all uses of the CallOp return. - // Replace all uses of the CallOp results so we can erase the CallOp. - // This TensorLoadOp must fold/DCE away or bufferization should be - // considered failed. - Value tensorLoad = - b.create(callOp.getLoc(), buffer); - oldRes.replaceAllUsesWith(tensorLoad); - // Add new op equivalence info. - state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer); - state.mapBuffer(tensorLoad, buffer); - continue; - } - - // TODO: Need to hoist above function boundary. - if (Operation *allocOp = getEquivalentAlloc(returnVal, state.aliasInfo)) { - hoistedArguments.push_back(allocOp->getResult(0)); - continue; - } - - // Other cases legitimately need to return a tensor, this is currently not - // supported. For instance, if hoisting across function boundary has - // failed, it may be due to e.g. data-dependent sizes. In such a case, we - // would we need a better type than memref. - resultTypes.push_back(returnType); - - int64_t returnIdx = returnOperand.getOperandNumber(); - return returnOp->emitError() - << "buffer result #" << returnIdx << " not produced by an alloc\n"; - } - } - - // 2. Compute bufferized FunctionType. - SmallVector argumentTypes{callOp->getOperandTypes()}; - ValueRange hoistedArgs{hoistedArguments}; - llvm::append_range(argumentTypes, hoistedArgs.getTypes()); - // Get the bufferized FunctionType for funcOp or construct it if not yet - // available. - FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( - funcOp, argumentTypes, resultTypes, bufferizedFunctionTypes); - - // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. - SmallVector newOperands; - newOperands.reserve(callOp->getNumOperands()); - for (OpOperand &opOperand : callOp->getOpOperands()) { - Value tensorOperand = opOperand.get(); - // Non-tensor operands are just copied. - if (!tensorOperand.getType().isa()) { - newOperands.push_back(tensorOperand); - continue; - } - - // Tensor operands are guaranteed to have been buferized. - int64_t idx = opOperand.getOperandNumber(); - Value buffer = state.lookupBuffer(tensorOperand); - - // Caller / callee type mistmatch is handled with a CastOp. - auto memRefType = bufferizedFuncType.getInput(idx); - // Since we don't yet have a clear layout story, buffer_cast may - // conservatively turn tensors into more dynamic memref than necessary. - // If the memref type of the callee fails, introduce an extra memref.cast - // that will either canonicalize away or fail compilation until we can do - // something better. - if (buffer.getType() != memRefType) { - Value castBuffer = - b.create(callOp.getLoc(), memRefType, buffer); - // Add new op equivalence info. - state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer); - state.mapBuffer(tensorOperand, castBuffer); - buffer = castBuffer; - } - newOperands.push_back(buffer); - } - - // 4. Create the new CallOp. - Operation *newCallOp = b.create(callOp.getLoc(), funcOp.sym_name(), - resultTypes, newOperands); - newCallOp->setAttrs(callOp->getAttrs()); - // Delete the op at the end of bufferization. - return success(); -} - /// FuncOp always creates TensorToMemRef ops. static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp, BufferizationState &state) { @@ -1065,20 +927,11 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, // Bufferization entry-point for functions. //===----------------------------------------------------------------------===// -LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp( - Operation *op, BufferizationState &state, - DenseMap *bufferizedFunctionTypes) { +LogicalResult +mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op, + BufferizationState &state) { OpBuilder b(op->getContext()); - // CallOps are handled separately. - if (auto callOp = dyn_cast(op)) { - LDBG("Begin bufferize:\n" << callOp << '\n'); - if (!bufferizedFunctionTypes) - llvm_unreachable( - "null bufferizedFunctionTypes when bufferizing CallOpInterface"); - return bufferize(b, callOp, state, *bufferizedFunctionTypes); - } - // Skip BufferCast and TensorLoad ops. if (isa(op)) return success(); @@ -1098,9 +951,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp( return op->emitError() << "unsupported op with tensors"; } -static LogicalResult bufferizeFuncOpInternals( - FuncOp funcOp, BufferizationState &state, - DenseMap &bufferizedFunctionTypes) { +static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp, + BufferizationState &state) { LLVM_DEBUG(llvm::dbgs() << "\n\n"); LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n'); OpBuilder b(funcOp->getContext()); @@ -1109,19 +961,9 @@ static LogicalResult bufferizeFuncOpInternals( if (failed(bufferize(b, funcOp, state))) return failure(); - // Cannot erase ops during the traversal. Do that afterwards. - SmallVector toErase; - auto walkFunc = [&](Operation *op) -> WalkResult { - if (failed(bufferizeOp(op, state, &bufferizedFunctionTypes))) + if (failed(bufferizeOp(op, state))) return failure(); - - // Register post-walk erasure, if necessary. - if (isa(op)) - if (llvm::any_of(op->getOperandTypes(), isaTensor) || - llvm::any_of(op->getResultTypes(), isaTensor)) - toErase.push_back(op); - return success(); }; @@ -1133,9 +975,6 @@ static LogicalResult bufferizeFuncOpInternals( LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n'); - for (Operation *op : toErase) - op->erase(); - return success(); } @@ -1516,12 +1355,12 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( ModuleOp moduleOp, const BufferizationOptions &options) { SmallVector orderedFuncOps; DenseMap> callerMap; - DenseMap bufferizedFunctionTypes; if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); DominanceInfo domInfo(moduleOp); - BufferizationAliasInfo aliasInfo(moduleOp); + BufferizationState state(moduleOp, *options.allocationFns); + BufferizationAliasInfo &aliasInfo = state.aliasInfo; // Interestingly, all function args that are not visible outside of a module // can be fully bufferized inplace by guaranteeing the CallOp is bufferized @@ -1564,16 +1403,12 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( // Bufferization phase. if (!options.testAnalysisOnly) { - BufferizationState state(aliasInfo, *options.allocationFns); - // Bufferize all ops in funcOp. - if (failed( - bufferizeFuncOpInternals(funcOp, state, bufferizedFunctionTypes))) + if (failed(bufferizeFuncOpInternals(funcOp, state))) return failure(); // Erase all obsolete ops. - for (Operation *op : state.obsoleteOps) - op->erase(); + state.eraseObsoleteOps(); } } // Annotate operations if we only want to report the analysis. @@ -1586,7 +1421,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo, - bufferizedFunctionTypes))) + state.bufferizedFunctionTypes))) return failure(); if (!options.allowReturnMemref && @@ -1986,10 +1821,142 @@ struct CallOpInterface return BufferRelation::Equivalent; } + /// In a first approximation, all the function arguments of a FuncOp are + /// marked inplaceable. For now, it is the responsibility of the `callOp` + /// bufferization to allow FuncOp that are inplaceable to write inPlace. LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { - llvm_unreachable("CallOps are handled separately"); - return failure(); + CallOp callOp = cast(op); + FuncOp funcOp = getCalledFunction(callOp); + assert(isa(callOp.getOperation()) && funcOp && + "expected Callop to a FuncOp"); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(callOp); + + // 1. Filter return types: + // - if the callee is bodiless / external, we cannot inspect it and we + // cannot assume anything. We can just assert that it does not return a + // tensor as this would have to bufferize to "return a memref", whose + // semantics is ill-defined. + // - if the callee has a body, we perform inter-procedural equivalence + // analysis. When successful, a result folds onto an operand. When + // unsuccessful, additional work is needed to either: + // * hoist a result into an inplaceable operand or + // * devise a better representation to truly return a buffer. + SmallVector resultTypes; + SmallVector hoistedArguments; + if (funcOp.body().empty()) { + if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) + return callOp->emitError() + << "cannot bufferize bodiless function that returns a tensor"; + } else { + ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + // For each FuncOp result, keep track of which inplace argument it reuses. + for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Type returnType = returnOperand.get().getType(); + if (!isaTensor(returnType)) { + resultTypes.push_back(returnType); + continue; + } + + // If return operand is equivalent to some bbArg, no need to return it. + Value returnVal = returnOperand.get(); + if (BlockArgument bbArg = + getEquivalentEnclosingFuncBBArg(returnVal, state.aliasInfo)) { + Value oldRes = callOp->getResult(returnOperand.getOperandNumber()); + int64_t idx = bbArg.getArgNumber(); + Value buffer = state.lookupBuffer(callOp->getOperand(idx)); + // Add CallOp operand/result equivalence: this is interprocedural + // info. + state.aliasInfo.insertNewBufferEquivalence(oldRes, buffer); + state.mapBuffer(oldRes, buffer); + // Add a TensorLoadOp to kill all uses of the CallOp return. + // Replace all uses of the CallOp results so we can erase the CallOp. + // This TensorLoadOp must fold/DCE away or bufferization should be + // considered failed. + Value tensorLoad = + b.create(callOp.getLoc(), buffer); + oldRes.replaceAllUsesWith(tensorLoad); + // Add new op equivalence info. + state.aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer); + state.mapBuffer(tensorLoad, buffer); + continue; + } + + // TODO: Need to hoist above function boundary. + if (Operation *allocOp = + getEquivalentAlloc(returnVal, state.aliasInfo)) { + hoistedArguments.push_back(allocOp->getResult(0)); + continue; + } + + // Other cases legitimately need to return a tensor, this is currently + // not supported. For instance, if hoisting across function boundary has + // failed, it may be due to e.g. data-dependent sizes. In such a case, + // we would we need a better type than memref. + resultTypes.push_back(returnType); + + int64_t returnIdx = returnOperand.getOperandNumber(); + return returnOp->emitError() << "buffer result #" << returnIdx + << " not produced by an alloc\n"; + } + } + + // 2. Compute bufferized FunctionType. + SmallVector argumentTypes{callOp->getOperandTypes()}; + ValueRange hoistedArgs{hoistedArguments}; + llvm::append_range(argumentTypes, hoistedArgs.getTypes()); + // Get the bufferized FunctionType for funcOp or construct it if not yet + // available. + FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( + funcOp, argumentTypes, resultTypes, state.bufferizedFunctionTypes); + + // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. + SmallVector newOperands; + newOperands.reserve(callOp->getNumOperands()); + for (OpOperand &opOperand : callOp->getOpOperands()) { + Value tensorOperand = opOperand.get(); + // Non-tensor operands are just copied. + if (!tensorOperand.getType().isa()) { + newOperands.push_back(tensorOperand); + continue; + } + + // Tensor operands are guaranteed to have been buferized. + int64_t idx = opOperand.getOperandNumber(); + Value buffer = state.lookupBuffer(tensorOperand); + + // Caller / callee type mistmatch is handled with a CastOp. + auto memRefType = bufferizedFuncType.getInput(idx); + // Since we don't yet have a clear layout story, buffer_cast may + // conservatively turn tensors into more dynamic memref than necessary. + // If the memref type of the callee fails, introduce an extra memref.cast + // that will either canonicalize away or fail compilation until we can do + // something better. + if (buffer.getType() != memRefType) { + Value castBuffer = + b.create(callOp.getLoc(), memRefType, buffer); + // Add new op equivalence info. + state.aliasInfo.insertNewBufferEquivalence(castBuffer, buffer); + state.mapBuffer(tensorOperand, castBuffer); + buffer = castBuffer; + } + newOperands.push_back(buffer); + } + + // 4. Create the new CallOp. + Operation *newCallOp = b.create(callOp.getLoc(), funcOp.sym_name(), + resultTypes, newOperands); + newCallOp->setAttrs(callOp->getAttrs()); + + // 5. Delete the op at the end of bufferization. + state.markOpObsolete(callOp); + + return success(); } };