From 72040bf7c8f24d8fb66d77bdae289a3040943555 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 7 Nov 2019 06:32:39 -0800 Subject: [PATCH] Update Linalg to use std.view Now that a view op has graduated to the std dialect, we can update Linalg to use it and remove ops that have become obsolete. As a byproduct, the linalg buffer and associated ops can also disappear. PiperOrigin-RevId: 279073591 --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 147 ------------ mlir/include/mlir/Dialect/Linalg/Passes.h | 3 +- .../mlir/Dialect/Linalg/Utils/Intrinsics.h | 7 - .../include/mlir/Dialect/Linalg/Utils/Utils.h | 7 +- mlir/include/mlir/EDSC/Intrinsics.h | 1 + .../StandardToLLVM/ConvertStandardToLLVM.cpp | 12 +- .../Linalg/Analysis/DependenceAnalysis.cpp | 2 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 151 +----------- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 216 +----------------- .../Dialect/Linalg/Transforms/Promotion.cpp | 67 ++++-- mlir/lib/Dialect/StandardOps/Ops.cpp | 1 + mlir/test/Dialect/Linalg/invalid.mlir | 70 ------ mlir/test/Dialect/Linalg/llvm.mlir | 77 +------ mlir/test/Dialect/Linalg/loops.mlir | 55 ++--- mlir/test/Dialect/Linalg/promote.mlir | 35 +-- mlir/test/Dialect/Linalg/roundtrip.mlir | 96 +------- .../linalg_integration_test.mlir | 55 +++-- 17 files changed, 149 insertions(+), 853 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 00185416b736..b7aca2173894 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -44,108 +44,6 @@ class Linalg_Op traits = []> : let parser = [{ return ::parse$cppClass(parser, result); }]; } -def BufferAllocOp : - Linalg_Op<"buffer_alloc">, - Arguments<(ins Variadic:$size, OptionalAttr:$alignment)>, - Results<(outs Buffer)> { - let summary = "buffer allocation operation"; - let description = [{ - The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type, - upon which a base view can be laid out to give it indexing semantics. - "buffer_alloc" takes a single argument, the size of the buffer to allocate - (in number of elements). - An optional alignment attribute may be specified in which case the actual - underlying allocation size may be increased. The base pointer is guaranteed - to be a multiple of `alignment`. Such an alignment must be a positive power - of 2. - - Examples: - - %0 = linalg.buffer_alloc(%arg0) : !linalg.buffer - - %1 = linalg.buffer_alloc(%arg0) { alignment = 16 } : - !linalg.buffer - - The size argument may be omitted if it is statically known, in which case it - must be reflected in the type. - - Example: - - %0 = linalg.buffer_alloc() : !linalg.buffer<4xf32> - }]; - let builders = [ - OpBuilder< - "Builder *b, OperationState &result, BufferType bufferType", [{ - result.addTypes(bufferType); - }]>, - OpBuilder< - "Builder *b, OperationState &result, BufferType bufferType, " - "unsigned alignment", [{ - build(b, result, bufferType); - if (alignment != 0) - result.addAttribute(BufferAllocOp::getAlignmentAttrName(), - b->getI64IntegerAttr(alignment)); - }]>, - OpBuilder< - "Builder *b, OperationState &result, BufferType bufferType, " - "Value *size, unsigned alignment", [{ - if (alignment == 0) - return build(b, result, bufferType, size); - build(b, result, bufferType, size, b->getI64IntegerAttr(alignment)); - }]>, - OpBuilder< - "Builder *b, OperationState &result, BufferType bufferType, Value *size", - [{ - result.addOperands(size); - result.addTypes(bufferType); - }]> - ]; - let extraClassDeclaration = [{ - static StringRef getAlignmentAttrName() { return "alignment"; } - BufferType getBufferType() { return getType().cast(); } - Type getElementType() { return getBufferType().getElementType(); } - }]; -} - -def BufferDeallocOp : - Linalg_Op<"buffer_dealloc">, - Arguments<(ins Buffer:$buffer)>, - Results<(outs)> { - let summary = "buffer allocation operation"; - let description = [{ - The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type. - - Example: - - linalg.buffer_dealloc %0 : !linalg.buffer - }]; - let extraClassDeclaration = [{ - BufferType getBufferType() { - return buffer()->getType().cast(); - } - }]; - - // Fully specified by traits. - let verifier = ?; -} - -def BufferSizeOp : - Linalg_Op<"buffer_size", [NoSideEffect]>, - Arguments<(ins Buffer:$buffer)>, - Results<(outs Index)> { - let summary = "buffer size operation"; - let description = [{ - The "linalg.buffer_size" operation takes a linalg.buffer and returns an - "index". - - Example: - - %0 = linalg.buffer_size %arg0 : !linalg.buffer - }]; - // Fully specified by traits. - let verifier = ?; -} - def RangeOp : Linalg_Op<"range", [NoSideEffect]>, Arguments<(ins Index:$min, Index:$max, Index:$step)>, @@ -329,51 +227,6 @@ def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>, }]; } -def ViewOp : Linalg_Op<"view", [NoSideEffect]>, - Arguments<(ins Buffer:$buffer, Variadic:$ranges)>, - Results<(outs AnyStridedMemRef)> { - let summary = "view operation"; - let description = [{ - The "linalg.view" op produces a strided memref which is a multi-dimensional - range abstraction on top of an underlying linalg.buffer. This gives an - indexing structure to an otherwise non-indexable linalg.buffer. - - A "linalg.view" takes a buffer and a variadic number of ranges and produces - a `view` of rank the number of ranges. The elemental type may not match the - buffer element type: - - Example: - - %1 = linalg.buffer_alloc %0 : !linalg.buffer - %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range - %3 = linalg.view %1[%2, %2] : - memref, stride_specification> - }]; - - let builders = [OpBuilder< - "Builder *b, OperationState &result, Value *buffer, " - "ArrayRef ranges, Type resultType = Type(), " - "ArrayRef attrs = {}">]; - - let verifier = [{ - if (getViewType().getRank() != llvm::size(ranges())) - return emitOpError("the view rank must be the number of its ranges"); - return success(); - }]; - - let extraClassDeclaration = [{ - enum { FirstIndexingOperand = 1 }; - unsigned getRank() { return getViewType().getRank(); } - Type getElementType() { return getViewType().getElementType(); } - MemRefType getViewType() { return getType().cast(); } - /// Get the underlying indexing at a given rank. - Value *getRange(unsigned rank) { - assert(rank < getRank() && "rank overflow"); - return *(ranges().begin() + rank); - } - }]; -} - def YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>, Arguments<(ins Variadic:$values)> { let summary = "Linalg yield operation"; diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index fb68c0ae9c37..8a01fe4d8bc7 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -36,7 +36,8 @@ std::unique_ptr> createLinalgFusionPass(); std::unique_ptr> createLinalgTilingPass(ArrayRef tileSizes = {}); -std::unique_ptr> createLinalgPromotionPass(); +std::unique_ptr> +createLinalgPromotionPass(bool dynamicBuffers); std::unique_ptr> createLowerLinalgToLoopsPass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h index 1c6bb68eb886..5a815ba158e5 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h @@ -22,22 +22,15 @@ namespace mlir { namespace linalg { -class BufferAllocOp; -class BufferDeallocOp; class CopyOp; class FillOp; class RangeOp; class SliceOp; -class ViewOp; namespace intrinsics { -using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder; -using buffer_dealloc = - mlir::edsc::intrinsics::OperationBuilder; using copy = mlir::edsc::intrinsics::OperationBuilder; using fill = mlir::edsc::intrinsics::OperationBuilder; using range = mlir::edsc::intrinsics::ValueBuilder; using slice = mlir::edsc::intrinsics::ValueBuilder; -using view = mlir::edsc::intrinsics::ValueBuilder; } // namespace intrinsics } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 0401d6987aa8..5fd3d97711b9 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -173,9 +173,10 @@ struct PromotionInfo { /// /// Returns a list of PromotionInfo which hold the promoted buffer and the /// full and partial views indexing into the buffer. -llvm::SmallVector promoteSubViews(OpBuilder &b, Location loc, - ArrayRef subViews, - OperationFolder *folder); +llvm::SmallVector +promoteSubViews(OpBuilder &b, Location loc, ArrayRef subViews, + bool promoteSubViews = false, + OperationFolder *folder = nullptr); /// Returns all the operands of `linalgOp` that are not views. /// Asserts that these operands are value types to allow transformations like diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h index cd48d9d7488d..e76bc2fea087 100644 --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -216,6 +216,7 @@ using std_load = ValueBuilder; using std_store = OperationBuilder; using subi = ValueBuilder; using vector_type_cast = ValueBuilder; +using view = ValueBuilder; /// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`. /// diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 89bf07f27cda..1584bd4d4ed4 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1430,10 +1430,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern { SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) - return op->emitWarning("Cannot cast to non-strided shape"), - matchFailure(); - if (strides.back() != 1) - return op->emitWarning("Cannot cast to non-contiguous shape"), + return op->emitWarning("cannot cast to non-strided shape"), matchFailure(); // Create the descriptor. @@ -1466,7 +1463,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern { rewriter.getI64ArrayAttr( LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + // Early exit for 0-D corner case. + if (viewMemRefType.getRank() == 0) + return rewriter.replaceOp(op, desc), matchSuccess(); + // Update sizes and strides. + if (strides.back() != 1) + return op->emitWarning("cannot cast to non-contiguous shape"), + matchFailure(); Value *stride = nullptr, *nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 9e57b7bb9deb..8772171a5c36 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -74,7 +74,7 @@ Value *Aliases::find(Value *v) { return it.first->second; } if (auto view = dyn_cast_or_null(v->getDefiningOp())) { - auto it = aliases.insert(std::make_pair(v, view.buffer())); + auto it = aliases.insert(std::make_pair(v, view.source())); return it.first->second; } if (auto view = dyn_cast_or_null(v->getDefiningOp())) { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index e34f5223996f..9a160f534fe9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -50,102 +50,6 @@ using namespace mlir::linalg; // LinalgOps.td), we define an overloaded `print` function and a // parse`className` function. -//===----------------------------------------------------------------------===// -// BufferAllocOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter &p, BufferAllocOp op) { - p << op.getOperationName() << " "; - if (!llvm::empty(op.size())) - p << *op.getOperand(0); - if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0) - p.printOptionalAttrDict(op.getAttrs()); - else - p.printOptionalAttrDict(op.getAttrs(), - BufferAllocOp::getAlignmentAttrName()); - p << " : " << op.getBufferType(); -} - -static ParseResult parseBufferAllocOp(OpAsmParser &parser, - OperationState &result) { - SmallVector sizeInfo; - BufferType bufferType; - auto indexTy = parser.getBuilder().getIndexType(); - if (parser.parseOperandList(sizeInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(bufferType)) - return failure(); - if (sizeInfo.empty()) - return parser.addTypeToList(bufferType, result.types); - return failure(parser.resolveOperands(sizeInfo, indexTy, result.operands) || - parser.addTypeToList(bufferType, result.types)); -} - -static LogicalResult verify(BufferAllocOp op) { - if (!op.getBufferType().hasConstantSize()) { - if (llvm::size(op.size()) != 1) - return op.emitOpError("expected one index operand"); - } else { // op.getBufferType().hasConstantSize() - if (!llvm::empty(op.size())) - return op.emitOpError("expected zero operand"); - if (op.getBufferType().getBufferSize().getValue() <= 0) - return op.emitOpError("expected nonnegative static buffer size"); - } - if (op.alignment().hasValue()) { - auto align = op.alignment().getValue(); - if (align.getSExtValue() < 0) - return op.emitOpError("expected positive alignment"); - if (!llvm::isPowerOf2_64(align.getZExtValue())) - return op.emitOpError("expected power of 2 alignment"); - } - if (!TensorType::isValidElementType(op.getElementType())) - return op.emitOpError("expected valid buffer element type"); - return success(); -} - -//===----------------------------------------------------------------------===// -// BufferDeallocOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter &p, BufferDeallocOp op) { - p << op.getOperationName() << " " << *op.buffer(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.getBufferType(); -} - -static ParseResult parseBufferDeallocOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType bufferInfo; - BufferType bufferType; - if (parser.parseOperand(bufferInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(bufferType)) - return failure(); - return parser.resolveOperands(bufferInfo, bufferType, result.operands); -} - -//===----------------------------------------------------------------------===// -// BufferSizeOp -//===----------------------------------------------------------------------===// - -static void print(OpAsmPrinter &p, BufferSizeOp op) { - p << op.getOperationName() << " " << *op.buffer(); - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.buffer()->getType(); -} - -static ParseResult parseBufferSizeOp(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType op; - Type type; - return failure( - parser.parseOperand(op) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type) || - parser.resolveOperand(op, type, result.operands) || - parser.addTypeToList(parser.getBuilder().getIndexType(), result.types)); -} - //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// @@ -426,7 +330,7 @@ static LogicalResult verify(SliceOp op) { unsigned rank = op.getBaseViewRank(); if (rank != llvm::size(op.indexings())) return op.emitOpError("expected ") - << op.getRank() << " indexings, got " << llvm::size(op.indexings()); + << rank << " indexings, got " << llvm::size(op.indexings()); unsigned index = 0; for (auto indexing : op.indexings()) { if (indexing->getType().isa()) @@ -562,59 +466,6 @@ static ParseResult parseTransposeOp(OpAsmParser &parser, parser.addTypeToList(type, result.types)); } -//===----------------------------------------------------------------------===// -// ViewOp -//===----------------------------------------------------------------------===// -void mlir::linalg::ViewOp::build(Builder *b, OperationState &result, - Value *buffer, ArrayRef ranges, - Type resultType, - ArrayRef attrs) { - // If the result type is not specified, assume sizes are fully dynamic. - // Strides are set to match an empty layout map which means "contiguous view". - if (!resultType) { - auto rank = ranges.size(); - SmallVector sizes(rank, -1); - Type elementType = buffer->getType().cast().getElementType(); - resultType = MemRefType::get(sizes, elementType, {}, 0); - } - build(b, result, resultType, buffer, ranges); - result.addAttributes(attrs); -} - -static void print(OpAsmPrinter &p, mlir::linalg::ViewOp op) { - p << op.getOperationName() << " " << *op.buffer() << "["; - interleaveComma(op.ranges(), p, [&](Value *v) { p << *v; }); - p << "] "; - p.printOptionalAttrDict(op.getAttrs()); - p << " : " << op.buffer()->getType() << " -> " << op.getType(); -} - -static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::OperandType bufferInfo; - SmallVector rangesInfo; - Type bType, vType; - if (parser.parseOperand(bufferInfo) || - parser.parseOperandList(rangesInfo, OpAsmParser::Delimiter::Square) || - parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || - parser.parseType(bType) || parser.parseArrow() || - parser.parseType(vType)) { - return failure(); - } - - MemRefType memRefType = vType.dyn_cast(); - if (!memRefType) - return parser.emitError(parser.getNameLoc(), "expected memref type"); - if (static_cast(memRefType.getRank()) != rangesInfo.size()) - return parser.emitError(parser.getNameLoc(), "expected ") - << memRefType.getRank() << " ranges"; - return failure( - parser.resolveOperand(bufferInfo, bType, result.operands) || - (!rangesInfo.empty() && - parser.resolveOperands(rangesInfo, RangeType::get(vType.getContext()), - result.operands)) || - parser.addTypeToList(memRefType, result.types)); -} - //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 3de6dc6b5010..6b51e039a5bd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -166,144 +166,6 @@ public: }; } // namespace -// BufferAllocOp creates a new `!linalg.buffer` value. -class BufferAllocOpConversion : public LLVMOpLowering { -public: - explicit BufferAllocOpConversion(MLIRContext *context, - LLVMTypeConverter &lowering_) - : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto indexType = IndexType::get(op->getContext()); - auto voidPtrTy = - LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); - auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)) - .cast(); - // Insert the `malloc` declaration if it is not already present. - auto module = op->getParentOfType(); - auto mallocFunc = module.lookupSymbol("malloc"); - if (!mallocFunc) { - OpBuilder moduleBuilder(op->getParentOfType().getBodyRegion()); - mallocFunc = moduleBuilder.create( - rewriter.getUnknownLoc(), "malloc", - LLVM::LLVMType::getFunctionTy(voidPtrTy, int64Ty, - /*isVarArg=*/false)); - } - - // Get MLIR types for injecting element pointer. - auto allocOp = cast(op); - auto elementType = allocOp.getElementType(); - uint64_t elementSize = 0; - if (auto vectorType = elementType.dyn_cast()) - elementSize = vectorType.getNumElements() * - llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8); - else - elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); - auto bufferType = allocOp.getBufferType(); - auto elementPtrType = getPtrToElementType(bufferType, lowering); - auto bufferDescriptorTy = convertLinalgType(bufferType, lowering); - - // Emit IR for creating a new buffer descriptor with an underlying malloc. - edsc::ScopedContext context(rewriter, op->getLoc()); - auto constantSize = bufferType.getBufferSize(); - Value *size = - constantSize - ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize)) - .getValue() - : operands[0]; - Value *allocSize = - mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize))); - Value *one = nullptr, *align = nullptr; - if (allocOp.alignment().hasValue()) { - one = constant(int64Ty, IntegerAttr::get(indexType, 1)); - align = - constant(int64Ty, rewriter.getIntegerAttr( - rewriter.getIndexType(), - allocOp.alignment().getValue().getSExtValue())); - allocSize = sub(add(allocSize, align), one); - } - - Value *allocated = - llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize) - .getOperation() - ->getResult(0); - Value *data = allocated; - if (allocOp.alignment().hasValue()) { - // offset = (align - (ptr % align))% align - Value *offset = - urem(sub(align, urem(ptrtoint(int64Ty, allocated), align)), align); - data = gep(voidPtrTy, allocated, offset); - } - data = bitcast(elementPtrType, data); - Value *desc = llvm_undef(bufferDescriptorTy); - desc = insertvalue(bufferDescriptorTy, desc, allocated, - rewriter.getI64ArrayAttr(kBasePtrPosInBuffer)); - desc = insertvalue(bufferDescriptorTy, desc, data, - rewriter.getI64ArrayAttr(kPtrPosInBuffer)); - desc = insertvalue(bufferDescriptorTy, desc, size, - rewriter.getI64ArrayAttr(kSizePosInBuffer)); - rewriter.replaceOp(op, desc); - return matchSuccess(); - } -}; - -// BufferDeallocOp creates no value. -class BufferDeallocOpConversion : public LLVMOpLowering { -public: - explicit BufferDeallocOpConversion(MLIRContext *context, - LLVMTypeConverter &lowering_) - : LLVMOpLowering(BufferDeallocOp::getOperationName(), context, - lowering_) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto voidTy = LLVM::LLVMType::getVoidTy(lowering.getDialect()); - auto voidPtrTy = - LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); - // Insert the `free` declaration if it is not already present. - auto module = op->getParentOfType(); - auto freeFunc = module.lookupSymbol("free"); - if (!freeFunc) { - OpBuilder moduleBuilder(op->getParentOfType().getBodyRegion()); - freeFunc = moduleBuilder.create( - rewriter.getUnknownLoc(), "free", - LLVM::LLVMType::getFunctionTy(voidTy, voidPtrTy, - /*isVarArg=*/false)); - } - - // Emit MLIR for buffer_dealloc. - BufferDeallocOpOperandAdaptor adaptor(operands); - edsc::ScopedContext context(rewriter, op->getLoc()); - Value *base = extractvalue(voidPtrTy, adaptor.buffer(), - rewriter.getI64ArrayAttr(kBasePtrPosInBuffer)); - llvm_call(ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), base); - rewriter.eraseOp(op); - return matchSuccess(); - } -}; - -// BufferSizeOp creates a new `index` value. -class BufferSizeOpConversion : public LLVMOpLowering { -public: - BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); - edsc::ScopedContext context(rewriter, op->getLoc()); - BufferSizeOpOperandAdaptor adaptor(operands); - rewriter.replaceOp( - op, {extractvalue(int64Ty, adaptor.buffer(), - rewriter.getI64ArrayAttr(kSizePosInBuffer))}); - return matchSuccess(); - } -}; - // RangeOp creates a new range descriptor. class RangeOpConversion : public LLVMOpLowering { public: @@ -480,78 +342,6 @@ public: } }; -/// Conversion pattern that transforms a linalg.view op into: -/// 1. A function entry `alloca` operation to allocate a ViewDescriptor. -/// 2. A load of the ViewDescriptor from the pointer allocated in 1. -/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size -/// and stride. -/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. -/// The linalg.view op is replaced by the alloca'ed pointer. -class ViewOpConversion : public LLVMOpLowering { -public: - explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(mlir::linalg::ViewOp::getOperationName(), context, - lowering_) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - mlir::linalg::ViewOpOperandAdaptor adaptor(operands); - - auto viewOp = cast(op); - BaseViewConversionHelper helper(op->getLoc(), viewOp.getViewType(), - rewriter, lowering); - LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; - Value *desc = helper.desc; - - Value *bufferDescriptor = adaptor.buffer(); - auto bufferTy = getPtrToElementType( - viewOp.buffer()->getType().cast(), lowering); - - edsc::ScopedContext context(rewriter, op->getLoc()); - - // Copy the buffer pointer from the old descriptor to the new one. - Value *bufferAsViewElementType = - bitcast(elementTy, extractvalue(bufferTy, bufferDescriptor, - helper.pos(kPtrPosInBuffer))); - desc = - insertvalue(desc, bufferAsViewElementType, helper.pos(kPtrPosInView)); - - // Zero base offset. - auto indexTy = rewriter.getIndexType(); - Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0)); - desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView)); - - // Corner case, no sizes or stride: early return the descriptor. - if (helper.zeroDMemRef) { - rewriter.replaceOp(op, desc); - return matchSuccess(); - } - - // Compute and insert view sizes (max - min along the range). - int numRanges = llvm::size(viewOp.ranges()); - Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1)); - for (int i = numRanges - 1; i >= 0; --i) { - // Update stride. - Value *rangeDescriptor = operands[1 + i]; - Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2)); - Value *stride = mul(runningStride, step); - desc = insertvalue(desc, stride, helper.pos({kStridePosInView, i})); - // Update size. - Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0)); - Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1)); - Value *size = sub(max, min); - desc = insertvalue(desc, size, helper.pos({kSizePosInView, i})); - // Update stride for the next dimension. - if (i > 0) - runningStride = mul(runningStride, max); - } - - rewriter.replaceOp(op, desc); - return matchSuccess(); - } -}; - // YieldOp produces and LLVM::ReturnOp. class YieldOpConversion : public LLVMOpLowering { public: @@ -731,10 +521,8 @@ static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert( - ctx, converter); + patterns.insert(ctx, converter); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index c9b7435a76e9..a23e68dc8f32 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -49,17 +49,26 @@ using llvm::SetVector; #define DEBUG_TYPE "linalg-promotion" +static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); +static llvm::cl::opt clPromoteDynamic( + "test-linalg-promote-dynamic", + llvm::cl::desc("Test generation of dynamic promoted buffers"), + llvm::cl::cat(clOptionsCategory), llvm::cl::init(false)); + static AffineMap getAffineDifferenceMap(MLIRContext *context) { AffineExpr d0(getAffineDimExpr(0, context)), d1(getAffineDimExpr(1, context)); return AffineMap::get(2, 0, {d0 - d1}); } -// TODO(ntv): replace this with 1-D memref alloc once there is an std.view op. -static Value *allocBuffer(Type elementType, Value *size) { - if (auto cst = dyn_cast_or_null(size->getDefiningOp())) - return buffer_alloc( - BufferType::get(size->getContext(), elementType, cst.getValue())); - return buffer_alloc(BufferType::get(size->getContext(), elementType), size); +static Value *allocBuffer(Type elementType, Value *size, bool dynamicBuffers) { + auto *ctx = size->getContext(); + auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); + if (!dynamicBuffers) + if (auto cst = dyn_cast_or_null(size->getDefiningOp())) + return alloc( + MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); + Value *mul = muli(constant_index(width), size); + return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); } // Performs promotion of a `subView` into a local buffer of the size of the @@ -81,6 +90,7 @@ static Value *allocBuffer(Type elementType, Value *size) { // by a partial `copy` op. static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, SubViewOp subView, + bool dynamicBuffers, OperationFolder *folder) { auto zero = constant_index(folder, 0); auto one = constant_index(folder, 1); @@ -101,18 +111,21 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, {rangeValue.max, rangeValue.min}, folder) .front(); allocSize = muli(folder, allocSize, d).getValue(); - fullRanges.push_back(range(folder, zero, d, one)); + fullRanges.push_back(d); partialRanges.push_back(range(folder, zero, dim(subView, rank), one)); } - auto *buffer = allocBuffer(viewType.getElementType(), allocSize); - auto fullLocalView = view(buffer, fullRanges); + SmallVector dynSizes(fullRanges.size(), -1); + auto *buffer = + allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers); + auto fullLocalView = view( + MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges); auto partialLocalView = slice(fullLocalView, partialRanges); return PromotionInfo{buffer, fullLocalView, partialLocalView}; } SmallVector mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, - ArrayRef subViews, + ArrayRef subViews, bool dynamicBuffers, OperationFolder *folder) { if (subViews.empty()) return {}; @@ -127,7 +140,8 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, // TODO(ntv): support more cases than just float. if (!viewType.getElementType().isa()) continue; - auto promotionInfo = promoteFullTileBuffer(b, loc, subView, folder); + auto promotionInfo = + promoteFullTileBuffer(b, loc, subView, dynamicBuffers, folder); promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo)); res.push_back(promotionInfo); } @@ -157,12 +171,13 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc, } static void promoteSubViewOperands(LinalgOp op, SetVector subViews, + bool dynamicBuffers, OperationFolder *folder) { // 1. Promote the specified views and use them in the new op. OpBuilder b(op); ScopedContext scope(b, op.getLoc()); - auto promotedBufferAndViews = - promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), folder); + auto promotedBufferAndViews = promoteSubViews( + b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers, folder); SmallVector opViews; opViews.reserve(op.getNumInputsAndOutputs()); SmallVector, 8> writebackViews; @@ -197,13 +212,13 @@ static void promoteSubViewOperands(LinalgOp op, SetVector subViews, // 4. Dealloc local buffers. for (const auto &pi : promotedBufferAndViews) - buffer_dealloc(pi.buffer); + dealloc(pi.buffer); } -static void promoteSubViews(FuncOp f) { +static void promoteSubViews(FuncOp f, bool dynamicBuffers) { SmallVector toErase; OperationFolder folder(f.getContext()); - f.walk([&folder, &toErase](LinalgOp op) { + f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. SetVector subViews; @@ -211,7 +226,7 @@ static void promoteSubViews(FuncOp f) { if (auto sv = dyn_cast_or_null(it->getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { - promoteSubViewOperands(op, subViews, &folder); + promoteSubViewOperands(op, subViews, dynamicBuffers, &folder); toErase.push_back(op); } }); @@ -221,13 +236,23 @@ static void promoteSubViews(FuncOp f) { namespace { struct LinalgPromotionPass : public FunctionPass { - void runOnFunction() override { promoteSubViews(getFunction()); } + LinalgPromotionPass() = default; + LinalgPromotionPass(bool dynamicBuffers) : dynamicBuffers(dynamicBuffers) {} + + void runOnFunction() override { + promoteSubViews(getFunction(), dynamicBuffers); + } + + bool dynamicBuffers; }; } // namespace -std::unique_ptr> mlir::linalg::createLinalgPromotionPass() { - return std::make_unique(); +std::unique_ptr> +mlir::linalg::createLinalgPromotionPass(bool dynamicBuffers) { + return std::make_unique(dynamicBuffers); } static PassRegistration - pass("linalg-promote-subviews", "promote subview ops to local buffers"); + pass("linalg-promote-subviews", "promote subview ops to local buffers", [] { + return std::make_unique(clPromoteDynamic); + }); diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 161a6c409c70..82d4324dff84 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" + using namespace mlir; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 9f1bb75364bd..3e4c542f223a 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1,49 +1,5 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics -// ----- - -func @buffer_alloc_single_index() { - // expected-error @+1 {{expected one index operand}} - %0 = linalg.buffer_alloc : !linalg.buffer -} - -// ----- - -func @buffer_alloc_unexpected_index(%s : index) { - // expected-error @+1 {{expected zero operand}} - %0 = linalg.buffer_alloc %s : !linalg.buffer<32xf32> -} - -// ----- - -func @buffer_alloc_nonegative_size() { - // expected-error @+1 {{expected nonnegative static buffer size}} - %0 = linalg.buffer_alloc : !linalg.buffer<0xf32> -} - -// ----- - -func @buffer_alloc_nonegative_alignment(%arg0: index) { - // expected-error @+1 {{expected positive alignment}} - %0 = linalg.buffer_alloc %arg0 {alignment = -123}: !linalg.buffer -} - -// ----- - -func @buffer_alloc_powerof2_alignment(%arg0: index) { - // expected-error @+1 {{expected power of 2 alignment}} - %0 = linalg.buffer_alloc %arg0 {alignment = 123}: !linalg.buffer -} - -// ----- - -func @buffer_valid_element_type() { - // expected-error @+1 {{expected valid buffer element type}} - %0 = linalg.buffer_alloc : !linalg.buffer<4xindex> -} - -// ----- - func @load_number_of_indices(%v : memref) { // expected-error @+2 {{incorrect number of indices for load}} %c0 = constant 0 : index @@ -99,22 +55,6 @@ func @transpose_bad_rank(%v : memref(off + M * i + j)> // ----- -func @view_type(%buf: !linalg.buffer, %min: index, %max: index, %step: index) { - // expected-error @+2 {{expected memref type}} - %r = linalg.range %min:%max:%step : !linalg.range - %0 = linalg.view %buf[%r]: !linalg.buffer -> index -} - -// ----- - -func @view_num_ranges(%buf: !linalg.buffer, %min: index, %max: index, %step: index) { - // expected-error @+2 {{expected 2 ranges}} - %r = linalg.range %min:%max:%step : !linalg.range - %0 = linalg.view %buf[%r]: !linalg.buffer -> memref(off + M * i + j)> -} - -// ----- - func @yield_parent(%arg0: memref(off + i)>) { // expected-error @+1 {{op expected 'linalg.generic' or 'linalg.indexed_generic' parent op}} linalg.yield %arg0: memref(off + i)> @@ -396,15 +336,5 @@ func @generic_fun_result_0_element_type(%arg0: memref(off + i) // ----- -// expected-error @+1 {{expected single element in size list}} -!invalid_type = type !linalg.buffer<1x1xf32> - -// ----- - -// expected-error @+1 {{expected '>'}} -!invalid_type = type !linalg<"buffer<1xf32"> - -// ----- - // expected-error @+1 {{expected valid keyword}} !invalid_type = type !linalg<"?"> diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir index 6cf67c8d3e38..e714fdacda8f 100644 --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -1,35 +1,6 @@ // RUN: mlir-opt %s -convert-linalg-to-llvm | FileCheck %s // RUN: mlir-opt %s -linalg-lower-to-loops -convert-linalg-to-llvm | FileCheck %s --check-prefix=LLVM-LOOPS -func @buffer_size(%arg0: !linalg.buffer) { - %c1 = constant 1 : index - %s = linalg.buffer_size %arg0 : !linalg.buffer - %t = addi %s, %c1 : index - return -} -// CHECK-LABEL: func @buffer_size -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i8*, float*, i64 }"> -// CHECK-NEXT: llvm.add {{.*}}, {{.*}} : !llvm.i64 - -func @buffer_alloc_aligned(%arg0: index) { - %s = linalg.buffer_alloc %arg0 {alignment=16} : !linalg.buffer - return -} -// CHECK-LABEL: func @buffer_alloc_aligned -// CHECK: %[[c4:.*]] = llvm.mlir.constant(4 : index) : !llvm.i64 -// CHECK: %[[m:.*]] = llvm.mul %arg0, %[[c4]] : !llvm.i64 -// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK: %[[c16:.*]] = llvm.mlir.constant(16 : index) : !llvm.i64 -// CHECK: %[[a:.*]] = llvm.add %[[m]], %[[c16]] : !llvm.i64 -// CHECK: %[[s:.*]] = llvm.sub %[[a]], %[[c1]] : !llvm.i64 -// CHECK: %[[alloc:.*]] = llvm.call @malloc(%[[s]]) : (!llvm.i64) -> !llvm<"i8*"> -// aligning `ptr` on `align` is done computing the address `ptr + (align - ptr % align) % align`. -// CHECK: %[[cast:.*]] = llvm.ptrtoint %[[alloc]] : !llvm<"i8*"> to !llvm.i64 -// CHECK: %[[rem:.*]] = llvm.urem %[[cast]], %[[c16]] : !llvm.i64 -// CHECK: %[[drem:.*]] = llvm.sub %[[c16]], %[[rem]] : !llvm.i64 -// CHECK: %[[off:.*]] = llvm.urem %[[drem]], %[[c16]] : !llvm.i64 -// CHECK: llvm.getelementptr %{{.*}}[%[[off]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*"> - func @range(%arg0: index) { %c0 = constant 0 : index %c1 = constant 1 : index @@ -44,48 +15,11 @@ func @range(%arg0: index) { // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> -func @view(%arg0: !linalg.buffer, %arg1: !linalg.range) { - %0 = linalg.view %arg0[%arg1] : !linalg.buffer -> memref - return -} -// CHECK-LABEL: func @view -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }"> -// CHECK-NEXT: llvm.bitcast {{.*}} : !llvm<"float*"> to !llvm<"float*"> -// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> -// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }"> -// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> -// CHECK-NEXT: llvm.sub %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> -// CHECK-NEXT: llvm.return - -func @view3d(%arg0: !linalg.buffer, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) { - %0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.buffer -> memref - return -} -// CHECK-LABEL: func @view3d -// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> -// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> -// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }"> -// CHECK: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }"> -// CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64 -// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> - -func @slice(%arg0: !linalg.buffer, %arg1: !linalg.range) { - %0 = linalg.view %arg0[%arg1] : !linalg.buffer -> memref - %1 = linalg.slice %0[%arg1] : memref, !linalg.range, memref +func @slice(%arg0: memref, %arg1: !linalg.range) { + %1 = linalg.slice %arg0[%arg1] : memref, !linalg.range, memref return } // CHECK-LABEL: func @slice -// insert ptr for view op -// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // insert data ptr for slice op // CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> @@ -122,13 +56,6 @@ func @dot(%arg0: memref, %arg1: memref, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64 }*">) -> () -func @dim(%arg0: memref) { - %0 = dim %arg0, 1 : memref - return -} -// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) { -// CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> - func @subview(%arg0: memref) { %c0 = constant 0 : index %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : memref diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir index 2b293a94ba98..d62a2885e1e5 100644 --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -12,22 +12,20 @@ // CHECK-DAG: #[[Stride2Dilation4:.*]] = (d0, d1) -> (d0 * 2 + d1 * 4) // CHECK-DAG: #[[Stride3Dilation5:.*]] = (d0, d1) -> (d0 * 3 + d1 * 5) -func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { + +func @matmul(%arg0: memref, %M: index, %N: index, %K: index) { %c0 = constant 0 : index %c1 = constant 1 : index - %I = linalg.range %c0:%arg1:%c1 : !linalg.range - %J = linalg.range %c0:%arg2:%c1 : !linalg.range - %K = linalg.range %c0:%arg3:%c1 : !linalg.range - %A = linalg.view %arg0[%I, %K] : !linalg.buffer -> memref - %B = linalg.view %arg0[%K, %J] : !linalg.buffer -> memref - %C = linalg.view %arg0[%I, %J] : !linalg.buffer -> memref + %A = view %arg0[%M, %K][%c0] : memref to memref + %B = view %arg0[%K, %N][%c0] : memref to memref + %C = view %arg0[%M, %N][%c0] : memref to memref linalg.matmul(%A, %B, %C) : memref, memref, memref return } -// CHECK-LABEL: func @matmul(%{{.*}}: !linalg.buffer, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref -// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref -// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref +// CHECK-LABEL: func @matmul(%{{.*}}: memref, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { +// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref +// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref +// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref // CHECK: %[[M:.*]] = dim %[[A]], 0 : memref // CHECK: %[[K:.*]] = dim %[[A]], 1 : memref // CHECK: %[[N:.*]] = dim %[[B]], 1 : memref @@ -41,21 +39,19 @@ func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: in // CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 // CHECK: store %[[res]], %[[C]][%{{.*}}, %{{.*}}] : memref -func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +func @matvec(%arg0: memref, %M: index, %N: index) { %c0 = constant 0 : index %c1 = constant 1 : index - %I = linalg.range %c0:%arg1:%c1 : !linalg.range - %J = linalg.range %c0:%arg2:%c1 : !linalg.range - %2 = linalg.view %arg0[%I, %J] : !linalg.buffer -> memref - %3 = linalg.view %arg0[%J] : !linalg.buffer -> memref - %4 = linalg.view %arg0[%I] : !linalg.buffer -> memref + %2 = view %arg0[%M, %N][%c0] : memref to memref + %3 = view %arg0[%M][%c0] : memref to memref + %4 = view %arg0[%N][%c0] : memref to memref linalg.matvec(%2, %3, %4) : memref, memref, memref return } -// CHECK-LABEL: func @matvec(%{{.*}}: !linalg.buffer, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref -// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref -// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref +// CHECK-LABEL: func @matvec(%{{.*}}: memref, %{{.*}}: index, %{{.*}}: index) { +// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref +// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref +// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref to memref // CHECK: %[[M:.*]] = dim %[[A]], 0 : memref // CHECK: %[[K:.*]] = dim %[[A]], 1 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} { @@ -67,20 +63,19 @@ func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: in // CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 // CHECK: store %[[res]], %[[C]][%{{.*}}] : memref -func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +func @dot(%arg0: memref, %M: index) { %c0 = constant 0 : index %c1 = constant 1 : index - %I = linalg.range %c0:%arg1:%c1 : !linalg.range - %1 = linalg.view %arg0[%I] : !linalg.buffer -> memref - %2 = linalg.view %arg0[%I] : !linalg.buffer -> memref - %3 = linalg.view %arg0[] : !linalg.buffer -> memref + %1 = view %arg0[%M][%c0] : memref to memref + %2 = view %arg0[%M][%c0] : memref to memref + %3 = view %arg0[][] : memref to memref linalg.dot(%1, %2, %3) : memref, memref, memref return } -// CHECK-LABEL: func @dot(%{{.*}}: !linalg.buffer, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref -// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer -> memref -// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer -> memref +// CHECK-LABEL: func @dot(%{{.*}}: memref, %{{.*}}: index) { +// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref +// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref to memref +// CHECK: %[[C:.*]] = std.view %{{.*}}[][] : memref to memref // CHECK: %[[K:.*]] = dim %[[A]], 0 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} { // CHECK-DAG: %[[a:.*]] = load %[[A]][%{{.*}}] : memref diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir index 91680d75ee1a..51f8b35012e0 100644 --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s +// RUN: mlir-opt %s -linalg-promote-subviews -test-linalg-promote-dynamic | FileCheck %s --check-prefix=DYNAMIC #map0 = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) #map1 = (d0) -> (d0 + 2) @@ -9,18 +10,15 @@ // CHECK-DAG: #[[strided2DnoOffset:.*]] = (d0, d1)[s0] -> (d0 * s0 + d1) module { - func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { + func @matmul(%A: memref, %M: index, %N: index, %K: index) { %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index %c0 = constant 0 : index %c1 = constant 1 : index - %0 = linalg.range %c0:%arg1:%c1 : !linalg.range - %1 = linalg.range %c0:%arg2:%c1 : !linalg.range - %2 = linalg.range %c0:%arg3:%c1 : !linalg.range - %3 = linalg.view %arg0[%0, %2] : !linalg.buffer -> memref - %4 = linalg.view %arg0[%2, %1] : !linalg.buffer -> memref - %5 = linalg.view %arg0[%0, %1] : !linalg.buffer -> memref + %3 = view %A[%M, %K][%c0] : memref to memref + %4 = view %A[%K, %N][%c0] : memref to memref + %5 = view %A[%M, %N][%c0] : memref to memref %6 = dim %3, 0 : memref %7 = dim %3, 1 : memref %8 = dim %4, 1 : memref @@ -44,7 +42,7 @@ module { } } -// CHECK-LABEL: func @matmul(%{{.*}}: !linalg.buffer, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { +// CHECK-LABEL: func @matmul(%{{.*}}: memref, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { @@ -52,16 +50,19 @@ module { // CHECK: %[[vB:.*]] = linalg.subview {{.*}} : memref // CHECK: %[[vC:.*]] = linalg.subview {{.*}} : memref /// -// CHECK: %[[tmpA:.*]] = linalg.buffer_alloc : !linalg.buffer<8xf32> -// CHECK: %[[fullA:.*]] = linalg.view %[[tmpA]][{{.*}}] : !linalg.buffer<8xf32> -> memref +// CHECK: %[[tmpA:.*]] = alloc() : memref<32xi8> +// CHECK: %[[fullA:.*]] = std.view %[[tmpA]][{{.*}}][] : memref<32xi8> to memref +// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref to memref // CHECK: %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref, !linalg.range, !linalg.range, memref /// -// CHECK: %[[tmpB:.*]] = linalg.buffer_alloc : !linalg.buffer<12xf32> -// CHECK: %[[fullB:.*]] = linalg.view %[[tmpB]][{{.*}}] : !linalg.buffer<12xf32> -> memref +// CHECK: %[[tmpB:.*]] = alloc() : memref<48xi8> +// CHECK: %[[fullB:.*]] = std.view %[[tmpB]][{{.*}}][] : memref<48xi8> to memref +// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref to memref // CHECK: %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref, !linalg.range, !linalg.range, memref /// -// CHECK: %[[tmpC:.*]] = linalg.buffer_alloc : !linalg.buffer<6xf32> -// CHECK: %[[fullC:.*]] = linalg.view %[[tmpC]][{{.*}}] : !linalg.buffer<6xf32> -> memref +// CHECK: %[[tmpC:.*]] = alloc() : memref<24xi8> +// CHECK: %[[fullC:.*]] = std.view %[[tmpC]][{{.*}}][] : memref<24xi8> to memref +// DYNAMIC: std.view %{{.*}}[{{.*}}][] : memref to memref // CHECK: %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref, !linalg.range, !linalg.range, memref // CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref, f32 @@ -75,6 +76,6 @@ module { // // CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref, memref // -// CHECK: linalg.buffer_dealloc %[[tmpA]] : !linalg.buffer<8xf32> -// CHECK: linalg.buffer_dealloc %[[tmpB]] : !linalg.buffer<12xf32> -// CHECK: linalg.buffer_dealloc %[[tmpC]] : !linalg.buffer<6xf32> +// CHECK: dealloc %[[tmpA]] : memref<32xi8> +// CHECK: dealloc %[[tmpB]] : memref<48xi8> +// CHECK: dealloc %[[tmpC]] : memref<24xi8> diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 0895ab717ce8..bffe599e723b 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -7,7 +7,6 @@ // CHECK-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0) // CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) -// CHECK-DAG: #[[strided2D42by1SymbolicOffset:.*]] = (d0, d1)[s0] -> (d0 * 42 + s0 + d1) // CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2) // CHECK-DAG: #[[strided6D:.*]] = (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5) @@ -21,60 +20,31 @@ func @range(%arg0: index, %arg1: index, %arg2: index) { // CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range -func @buffer_size(%arg0: !linalg.buffer) -> index { - %0 = linalg.buffer_size %arg0 : !linalg.buffer - return %0 : index -} -// CHECK-LABEL: func @buffer_size -// CHECK: linalg.buffer_size {{.*}} : !linalg.buffer - -func @buffer(%arg0: index, %arg1: index) { - %0 = muli %arg0, %arg0 : index - %1 = linalg.buffer_alloc %0 : !linalg.buffer> - %2 = linalg.buffer_alloc %0 {alignment = 16} : !linalg.buffer> - %3 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>> - %4 = linalg.buffer_alloc {alignment = 32} : !linalg.buffer<17xvector<4xi8>> - linalg.buffer_dealloc %4 : !linalg.buffer<17xvector<4xi8>> - linalg.buffer_dealloc %3 : !linalg.buffer<17xvector<4xi8>> - linalg.buffer_dealloc %2 : !linalg.buffer> - linalg.buffer_dealloc %1 : !linalg.buffer> - return -} -// CHECK-LABEL: func @buffer(%{{.*}}: index, %{{.*}}: index) { -// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index -// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer> -// CHECK-NEXT: linalg.buffer_alloc %{{.*}} {alignment = 16 : i64} : !linalg.buffer> -// CHECK-NEXT: linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>> -// CHECK-NEXT: linalg.buffer_alloc {alignment = 32 : i64} : !linalg.buffer<17xvector<4xi8>> -// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>> -// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>> -// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer> -// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer> - func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { + %c0 = constant 0 : index %0 = muli %arg0, %arg0 : index - %1 = linalg.buffer_alloc %0 : !linalg.buffer - %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range - %3 = linalg.view %1[%2, %2] : !linalg.buffer -> memref + %1 = alloc (%0) : memref + %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range + %3 = view %1[%arg0, %arg0][%c0] : memref to memref %4 = linalg.slice %3[%2, %2] : memref, !linalg.range, !linalg.range, memref %5 = linalg.slice %3[%2, %arg2] : memref, !linalg.range, index, memref %6 = linalg.slice %3[%arg2, %2] : memref, index, !linalg.range, memref %7 = linalg.slice %3[%arg2, %arg3] : memref, index, index, memref - %8 = linalg.view %1[%2, %2] : !linalg.buffer -> memref, offset: ?, strides: [?, 1]> - linalg.buffer_dealloc %1 : !linalg.buffer + %8 = view %1[%arg0, %arg0][%c0] : memref to memref, offset: ?, strides: [?, 1]> + dealloc %1 : memref return } // CHECK-LABEL: func @views(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index -// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer -// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range -// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer -> memref +// CHECK: muli %{{.*}}, %{{.*}} : index +// CHECK-NEXT: alloc(%{{.*}}) : memref +// CHECK-NEXT: range +// CHECK-NEXT: std.view %{{.*}}[%{{.*}}][%{{.*}}] : memref to memref // CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, !linalg.range, !linalg.range, memref // CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, !linalg.range, index, memref // CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, index, !linalg.range, memref // CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref, index, index, memref -// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer -> memref, #[[strided2D]]> -// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer +// CHECK-NEXT: view %{{.*}}[%{{.*}}][%{{.*}}] : memref to memref, #[[strided2D]]> +// CHECK-NEXT: dealloc %{{.*}} : memref func @ops(%arg0: memref, %arg1: memref, %arg2: memref, %arg3: memref) { linalg.matmul(%arg0, %arg0, %arg0) : memref, memref, memref @@ -88,41 +58,6 @@ func @ops(%arg0: memref, %arg1: memref, memref, memref // CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref -func @dim(%arg0: memref) { - %0 = dim %arg0, 1 : memref - %1 = linalg.buffer_alloc %0 : !linalg.buffer - linalg.buffer_dealloc %1 : !linalg.buffer - return -} -// CHECK-LABEL: func @dim( -// CHECK: %{{.*}}: memref) { -// CHECK-NEXT: dim %{{.*}}, 1 : memref -// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer -// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer - -func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) { - loop.for %i0 = %arg0 to %arg1 step %arg2 { - loop.for %i1 = %arg0 to %arg1 step %arg2 { - %min_cmp = cmpi "slt", %i0, %i1 : index - %min = select %min_cmp, %i0, %i1 : index - %max_cmp = cmpi "sge", %i0, %i1 : index - %max = select %max_cmp, %i0, %i1 : index - loop.for %i2 = %min to %max step %i1 { - } - } - } - return -} -// CHECK-LABEL: func @linalg_for( -// CHECK: %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { -// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK-NEXT: cmpi "slt", %{{.*}}, %{{.*}} : index -// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index -// CHECK-NEXT: cmpi "sge", %{{.*}}, %{{.*}} : index -// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index -// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} { - func @fill_view(%arg0: memref, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref, f32 return @@ -190,13 +125,6 @@ func @subview(%arg0: memref, offset: ?, strides: [?, 1]>) { // CHECK: constant 0 : index // CHECK: linalg.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref, #[[strided2D]]> -func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) { - %c0 = linalg.buffer_alloc : !linalg.buffer<17xf32> - %c1 = linalg.range %arg0:%arg1:%arg2 : !linalg.range - %c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> memref - return -} - #accesses = [ (i, j, k) -> (j, i), (i, j, k) -> (i, k, i + j) diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir index c533ed1868d4..b153faea5e67 100644 --- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir +++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir @@ -5,18 +5,19 @@ // RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -linalg-lower-to-loops -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s // RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-promote-subviews -convert-linalg-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s -#strided1D = (d0)[s0] -> (d0 + s0) -#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) +#strided1D = (d0) -> (d0) +#strided2D = (d0, d1)[s0] -> (d0 * s0 + d1) // Creates and returns a 1-D buffer of size %s filled with the value %f -func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer { +func @alloc_filled_f32(%s : index, %f : f32) -> memref { %c0 = constant 0 : index %c1 = constant 1 : index - %buf = linalg.buffer_alloc %s {alignment = 256} : !linalg.buffer - %R = linalg.range %c0:%s:%c1 : !linalg.range - %V = linalg.view %buf[%R] : !linalg.buffer -> memref + %c4 = constant 4 : index + %s4 = muli %s, %c4: index + %buf = alloc(%s4) {alignment = 256} : memref + %V = view %buf[%s][] : memref to memref linalg.fill(%V, %f) : memref, f32 - return %buf : !linalg.buffer + return %buf : memref } // Test for linalg.dot. @@ -28,21 +29,20 @@ func @dot() -> f32 { %f1 = constant 1.00000e+00 : f32 %f2 = constant 2.00000e+00 : f32 - %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer) - %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer) - %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer) + %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (memref) + %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (memref) + %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (memref) - %R = linalg.range %c0:%c16:%c1 : !linalg.range - %A = linalg.view %bA[%R] : !linalg.buffer -> memref - %B = linalg.view %bB[%R] : !linalg.buffer -> memref - %C = linalg.view %bC[] : !linalg.buffer -> memref + %A = view %bA[%c16][] : memref to memref + %B = view %bB[%c16][] : memref to memref + %C = view %bC[][] : memref to memref linalg.dot(%A, %B, %C) : memref, memref, memref %res = load %C[] : memref - linalg.buffer_dealloc %bC : !linalg.buffer - linalg.buffer_dealloc %bB : !linalg.buffer - linalg.buffer_dealloc %bA : !linalg.buffer + dealloc %bC : memref + dealloc %bB : memref + dealloc %bA : memref return %res : f32 } @@ -61,23 +61,20 @@ func @matmul() -> f32 { %f2 = constant 2.00000e+00 : f32 %f10 = constant 10.00000e+00 : f32 - %bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (!linalg.buffer) - %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (!linalg.buffer) - %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (!linalg.buffer) + %bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (memref) + %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (memref) + %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (memref) - %M = linalg.range %c0:%c10:%c1 : !linalg.range - %N = linalg.range %c0:%c10:%c1 : !linalg.range - %K = linalg.range %c0:%c16:%c1 : !linalg.range - %A = linalg.view %bA[%M, %K] : !linalg.buffer -> memref - %B = linalg.view %bB[%K, %N] : !linalg.buffer -> memref - %C = linalg.view %bC[%M, %N] : !linalg.buffer -> memref + %A = view %bA[%c10, %c16][] : memref to memref + %B = view %bB[%c16, %c10][] : memref to memref + %C = view %bC[%c10, %c10][] : memref to memref linalg.matmul(%A, %B, %C) : memref, memref, memref %res = load %C[%c6, %c7] : memref - linalg.buffer_dealloc %bC : !linalg.buffer - linalg.buffer_dealloc %bB : !linalg.buffer - linalg.buffer_dealloc %bA : !linalg.buffer + dealloc %bC : memref + dealloc %bB : memref + dealloc %bA : memref return %res : f32 }