From 268524238e903261231c6dafca65d9831e3ca34c Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 26 Jan 2022 23:53:55 +0900 Subject: [PATCH] [mlir][bufferization] Add an option to use memref types without layout maps This is for compatibility with existing bufferization passes. Also clean up memref type generation a bit. Differential Revision: https://reviews.llvm.org/D118243 --- .../IR/BufferizableOpInterface.h | 18 +++--- mlir/include/mlir/Dialect/Linalg/Passes.td | 3 + .../IR/BufferizableOpInterface.cpp | 60 +++++++++++-------- .../ModuleBufferization.cpp | 21 +++---- .../SCFInterfaceImpl.cpp | 14 ++--- .../Transforms/ComprehensiveBufferizePass.cpp | 1 + .../BufferizableOpInterfaceImpl.cpp | 13 +--- ...omprehensive-module-bufferize-partial.mlir | 19 ++++-- .../comprehensive-module-bufferize.mlir | 7 +++ .../Linalg/TestComprehensiveBufferize.cpp | 5 ++ 10 files changed, 94 insertions(+), 67 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index bbac6e59aeeb..5107710413ea 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -98,6 +98,10 @@ struct BufferizationOptions { /// Should be used only with `testAnalysisOnly = true`. unsigned analysisFuzzerSeed = 0; + /// Specifies whether fully dynamic layout maps should be used on ranked + /// MemRef types. If false, MemRef types will have no layout maps. + bool fullyDynamicLayoutMaps = true; + /// If set to `true`, does not modify the IR apart from adding attributes (for /// checking the results of the analysis) and post analysis steps. bool testAnalysisOnly = false; @@ -282,21 +286,17 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op, } /// Return a contiguous MemRefType (i.e. with canonical/empty layout map) -/// with the same shape as `shapedType` and specified `layout` and -/// `addressSpace`. +/// with the same shape as `shapedType` and specified `addressSpace`. MemRefType getContiguousMemRefType(ShapedType shapedType, - MemRefLayoutAttrInterface layout = {}, Attribute memorySpace = {}); -/// Return an UnrankedMemRefType with the given element type and memory space. -UnrankedMemRefType getUnrankedMemRefType(Type elementType, - Attribute memorySpace = {}); - /// Return a MemRefType to which the `tensorType` can be bufferized in a /// composable fashion. The layout must be the most dynamic possible and /// canonicalize away once bufferization is finished. -MemRefType getDynamicMemRefType(RankedTensorType tensorType, - unsigned addressSpace = 0); +BaseMemRefType getMemRefType(TensorType tensorType, + const BufferizationOptions &options, + MemRefLayoutAttrInterface layout = {}, + Attribute memorySpace = {}); /// Creates a memref allocation with the given type and dynamic extents. FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index c67ebc84a5cf..fac60e2fd2b8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -55,6 +55,9 @@ def LinalgComprehensiveModuleBufferize : Option<"useLinalgCopy", "use-linalg-copy", "bool", /*default=*/"false", "Use a copy operation implemented as a Linalg op.">, + Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool", + /*default=*/"true", + "Generate MemRef types with dynamic offset+strides by default.">, Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned", /*default=*/"0", "Analyze ops in random order with a given seed (fuzzer)">, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 7d91229625cc..276950d4ef19 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -210,8 +210,10 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { #endif } -static Value lookupBuffer(RewriterBase &rewriter, Value tensor) { - assert(tensor.getType().isa() && "unexpected non-tensor type"); +static Value lookupBuffer(RewriterBase &rewriter, Value tensor, + const BufferizationOptions &options) { + auto tensorType = tensor.getType().dyn_cast(); + assert(tensorType && "unexpected non-tensor type"); // Replace "%t = to_tensor %m" with %m. if (auto toTensorOp = tensor.getDefiningOp()) @@ -220,13 +222,7 @@ static Value lookupBuffer(RewriterBase &rewriter, Value tensor) { // Insert to_memref op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, tensor); - Type memrefType; - if (auto rankedTensorType = tensor.getType().dyn_cast()) { - memrefType = getDynamicMemRefType(rankedTensorType); - } else { - memrefType = getUnrankedMemRefType( - tensor.getType().cast().getElementType()); - } + Type memrefType = getMemRefType(tensorType, options); ensureToMemrefOpIsValid(tensor, memrefType); return rewriter.create(tensor.getLoc(), memrefType, tensor); @@ -242,7 +238,7 @@ FailureOr BufferizationState::getBuffer( Operation *op = opOperand.getOwner(); Location loc = op->getLoc(); Value operand = opOperand.get(); - Value operandBuffer = lookupBuffer(rewriter, operand); + Value operandBuffer = lookupBuffer(rewriter, operand, options); if (forceInPlace || isInPlace(opOperand)) return operandBuffer; @@ -513,27 +509,43 @@ bool bufferization::isFunctionArgument(Value value) { return isa(bbArg.getOwner()->getParentOp()); } -MemRefType -bufferization::getContiguousMemRefType(ShapedType shapedType, - MemRefLayoutAttrInterface layout, - Attribute memorySpace) { +MemRefType bufferization::getContiguousMemRefType(ShapedType shapedType, + Attribute memorySpace) { + MemRefLayoutAttrInterface layout = {}; return MemRefType::get(shapedType.getShape(), shapedType.getElementType(), layout, memorySpace); } -UnrankedMemRefType bufferization::getUnrankedMemRefType(Type elementType, - Attribute memorySpace) { - return UnrankedMemRefType::get(elementType, memorySpace); -} +BaseMemRefType bufferization::getMemRefType(TensorType tensorType, + const BufferizationOptions &options, + MemRefLayoutAttrInterface layout, + Attribute memorySpace) { + // Case 1: Unranked memref type. + if (auto unrankedTensorType = tensorType.dyn_cast()) { + assert(!layout && "UnrankedTensorType cannot have a layout map"); + return UnrankedMemRefType::get(unrankedTensorType.getElementType(), + memorySpace); + } -MemRefType bufferization::getDynamicMemRefType(RankedTensorType tensorType, - unsigned addressSpace) { + // Case 2: Ranked memref type with specified layout. If fully dynamic layout + // maps are not requested, generate a type with `layout`, which is empty (no + // layout map) by default. + auto rankedTensorType = tensorType.cast(); + if (layout || !options.fullyDynamicLayoutMaps) { + return MemRefType::get(rankedTensorType.getShape(), + rankedTensorType.getElementType(), layout, + memorySpace); + } + + // Case 3: Ranked memref type with unspecified layout. Choose the most dynamic + // one. // TODO: address space decisions to connect with the actual alloc. int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; - SmallVector dynamicStrides(tensorType.getRank(), + SmallVector dynamicStrides(rankedTensorType.getRank(), ShapedType::kDynamicStrideOrOffset); AffineMap stridedLayout = makeStridedLinearLayoutMap( - dynamicStrides, dynamicOffset, tensorType.getContext()); - return MemRefType::get(tensorType.getShape(), tensorType.getElementType(), - stridedLayout, addressSpace); + dynamicStrides, dynamicOffset, rankedTensorType.getContext()); + return MemRefType::get(rankedTensorType.getShape(), + rankedTensorType.getElementType(), stridedLayout, + memorySpace); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp index 0fe79862a69d..63ffb0932007 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -308,16 +308,15 @@ static FuncOp getCalledFunction(CallOpInterface callOp) { /// dynamic buffer type supported. /// A later pass across all CallOps in the module can decide whether to simplify /// the types of to version according to some cost model. -static FunctionType getBufferizedFunctionType(MLIRContext *ctx, - TypeRange argumentTypes, - TypeRange resultTypes) { - auto rewrite = [](Type t) -> Type { +static FunctionType +getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes, + TypeRange resultTypes, + const BufferizationOptions &options) { + auto rewrite = [&](Type t) -> Type { // TODO: non-zero address space. // TODO: layout information if relevant. - if (auto rankedTensorType = t.dyn_cast()) - return getDynamicMemRefType(rankedTensorType); if (auto tensorType = t.dyn_cast()) - return getUnrankedMemRefType(tensorType.getElementType()); + return getMemRefType(tensorType, options); return t; }; auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite)); @@ -398,7 +397,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, return funcOp->emitError() << "cannot bufferize bodiless function that " << "returns a tensor"; FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{}); + funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{}, + state.getOptions()); funcOp.setType(bufferizedFuncType); return success(); } @@ -431,7 +431,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, // 2. Rewrite the terminator without the inPlace bufferizable values. ValueRange retValues{returnValues}; FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes()); + funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes(), + state.getOptions()); OpBuilder b(returnOp); b.create(returnOp.getLoc(), returnValues); returnOp->erase(); @@ -822,7 +823,7 @@ struct CallOpInterface // Get the bufferized FunctionType for funcOp or construct it if not yet // available. FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), argumentTypes, resultTypes); + funcOp.getContext(), argumentTypes, resultTypes, state.getOptions()); // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. for (OpOperand &opOperand : callOp->getOpOperands()) { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp index 87dd5b09773d..cc8517f8f119 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -74,11 +74,8 @@ struct ExecuteRegionOpInterface // Compute new result types. SmallVector newResultTypes; for (Type type : executeRegionOp->getResultTypes()) { - if (auto rankedTensorType = type.dyn_cast()) { - newResultTypes.push_back(getDynamicMemRefType(rankedTensorType)); - } else if (auto tensorType = type.dyn_cast()) { - newResultTypes.push_back( - getUnrankedMemRefType(tensorType.getElementType())); + if (auto tensorType = type.dyn_cast()) { + newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); } else { newResultTypes.push_back(type); } @@ -186,11 +183,8 @@ struct IfOpInterface // Compute new types of the bufferized scf.if op. SmallVector newTypes; for (Type returnType : ifOp->getResultTypes()) { - if (returnType.isa()) { - assert(returnType.isa() && - "unsupported unranked tensor"); - newTypes.push_back( - getDynamicMemRefType(returnType.cast())); + if (auto tensorType = returnType.dyn_cast()) { + newTypes.push_back(getMemRefType(tensorType, state.getOptions())); } else { newTypes.push_back(returnType); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp index 12d43300aacf..3b3b1e4e76ec 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -120,6 +120,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() { options->allowUnknownOps = allowUnknownOps; options->analysisFuzzerSeed = analysisFuzzerSeed; options->createDeallocs = createDeallocs; + options->fullyDynamicLayoutMaps = fullyDynamicLayoutMaps; options->printConflicts = printConflicts; options->testAnalysisOnly = testAnalysisOnly; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 1c1226b45168..f3c9fb5aeb48 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -65,14 +65,8 @@ struct CastOpInterface layout = rankedMemRefType.getLayout(); // Compute the new memref type. - Type resultMemRefType; - if (resultTensorType.isa()) { - resultMemRefType = - getContiguousMemRefType(resultTensorType, layout, memorySpace); - } else { - resultMemRefType = - getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace); - } + Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(), + layout, memorySpace); // Replace the op with a memref.cast. assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), @@ -263,8 +257,7 @@ struct FromElementsOpInterface Location loc = op->getLoc(); auto tensorType = fromElementsOp.getType().cast(); auto shape = tensorType.getShape(); - MemRefType resultType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + MemRefType resultType = getContiguousMemRefType(tensorType); FailureOr maybeBuffer = createAlloc(rewriter, loc, resultType, {}, /*deallocMemref=*/state.getOptions().createDeallocs, diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir index aadbeaff86ff..ea1251fc080b 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir @@ -1,5 +1,8 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s +// Test bufferization using memref types that have no layout map. +// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops fully-dynamic-layout-maps=0" -split-input-file | FileCheck %s --check-prefix=CHECK-NO-LAYOUT-MAP + // Run fuzzer with different seeds. // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null @@ -8,20 +11,28 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR // RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=scf allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF +// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + // CHECK-LABEL: func @use_of_unknown_op_1( -// CHECK-SAME: %[[m1:.*]]: memref +// CHECK-NO-LAYOUT-MAP-LABEL: func @use_of_unknown_op_1( +// CHECK-NO-LAYOUT-MAP-SAME: %[[m1:.*]]: memref) func @use_of_unknown_op_1(%t1: tensor {linalg.inplaceable = true}) -> vector<5xf32> { // ToTensorOp is generated because the function is bufferized and has a // memref block argument. - // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]] + // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]] : memref // CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[m1_tensor]]) + // CHECK-NO-LAYOUT-MAP: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]] : memref + // CHECK-NO-LAYOUT-MAP: %[[dummy:.*]] = "test.dummy_op"(%[[m1_tensor]]) %0 = "test.dummy_op"(%t1) : (tensor) -> tensor %idx = arith.constant 0 : index %cst = arith.constant 0.0 : f32 - // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] - // CHECK: vector.transfer_read %[[dummy_memref]] + // CHECK: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref + // CHECK: vector.transfer_read %[[dummy_memref]][%{{.*}}], %{{.*}} : memref + // CHECK-NO-LAYOUT-MAP: %[[dummy_memref:.*]] = bufferization.to_memref %[[dummy]] : memref + // CHECK-NO-LAYOUT-MAP: vector.transfer_read %[[dummy_memref]][%{{.*}}], %{{.*}} : memref %1 = vector.transfer_read %0[%idx], %cst : tensor, vector<5xf32> return %1 : vector<5xf32> } diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir index c4ea9a48b8ec..42734ae1e9ad 100644 --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -5,7 +5,11 @@ // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// Test bufferization using memref types that have no layout map. +// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref fully-dynamic-layout-maps=0" -split-input-file | FileCheck %s --check-prefix=CHECK-NO-LAYOUT-MAP + // CHECK-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> { +// CHECK-NO-LAYOUT-MAP-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> func @transfer_read( %A : tensor {linalg.inplaceable = false}) -> (vector<4xf32>) @@ -26,6 +30,7 @@ func @transfer_read( // CHECK-LABEL: func @fill_inplace( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref +// CHECK-NO-LAYOUT-MAP-LABEL: func @fill_inplace(%{{.*}}: memref) { func @fill_inplace( %A : tensor {linalg.inplaceable = true}) -> tensor @@ -63,6 +68,7 @@ func @tensor_extract(%A : tensor {linalg.inplaceable = false}) -> (f32) { /// No linalg.inplaceable flag, must allocate. // CHECK-LABEL: func @not_inplace( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref) -> memref { +// CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref) -> memref func @not_inplace( %A : tensor {linalg.inplaceable = false}) -> tensor @@ -86,6 +92,7 @@ func @not_inplace( // CHECK-LABEL: func @not_inplace // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref) { +// CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref) { func @not_inplace( %A : tensor {linalg.inplaceable = true}) -> tensor diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp index a9b5ab206d42..6e72f77ee349 100644 --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -91,6 +91,10 @@ struct TestComprehensiveFunctionBufferize *this, "dialect-filter", llvm::cl::desc("Bufferize only ops from the specified dialects"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + Option fullyDynamicLayoutMaps{ + *this, "fully-dynamic-layout-maps", + llvm::cl::desc("Use fully dynamic layout maps on memref types"), + llvm::cl::init(true)}; Option createDeallocs{ *this, "create-deallocs", llvm::cl::desc("Specify if buffers should be deallocated"), @@ -108,6 +112,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() { options->allowUnknownOps = allowUnknownOps; options->testAnalysisOnly = testAnalysisOnly; options->analysisFuzzerSeed = analysisFuzzerSeed; + options->fullyDynamicLayoutMaps = fullyDynamicLayoutMaps; options->createDeallocs = createDeallocs; if (dialectFilter.hasValue()) {