From 6a7a1ca25d0f3fa76bd05b10f35678b3cc3defca Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Sun, 30 Jun 2019 07:07:50 -0700 Subject: [PATCH] Move BufferAllocOp and BufferDeallocOp to ODS This CL also fixes a parsing issue in the BufferType, adds LLVM lowering support for handling the static constant buffer size and a roundtrip test. PiperOrigin-RevId: 255834356 --- mlir/include/mlir/Linalg/IR/LinalgOps.h | 50 ------- mlir/include/mlir/Linalg/IR/LinalgOps.td | 59 +++++++++ mlir/lib/Linalg/IR/LinalgOps.cpp | 124 ++++++++---------- mlir/lib/Linalg/IR/LinalgTypes.cpp | 7 +- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 13 +- mlir/test/Linalg/roundtrip.mlir | 4 + 6 files changed, 129 insertions(+), 128 deletions(-) diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Linalg/IR/LinalgOps.h index ff04c366ebf5..72eabc2f1987 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.h @@ -29,56 +29,6 @@ class OperationFolder; namespace linalg { -/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type, -/// upon which a base view can be laid out to give it indexing semantics. -/// "buffer_alloc" takes a single argument, the size of the buffer to allocate -/// (in number of elements). -/// -/// ```{.mlir} -/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer -/// ``` -class BufferAllocOp - : public Op { -public: - using Op::Op; - - // Hooks to customize the behavior of this op. - static llvm::StringRef getOperationName() { return "linalg.buffer_alloc"; } - static void build(Builder *b, OperationState *result, Type type, Value *size); - LogicalResult verify(); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - - // Op-specific functionality. - Value *size() { return getOperand(); } - BufferType getBufferType() { return getType().cast(); } - Type getElementType() { return getBufferType().getElementType(); } -}; - -/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type. -/// -/// ```{.mlir} -/// linalg.buffer_dealloc %0 : !linalg.buffer -/// ``` -class BufferDeallocOp - : public Op { -public: - using Op::Op; - - // Hooks to customize the behavior of this op. - static llvm::StringRef getOperationName() { return "linalg.buffer_dealloc"; } - static void build(Builder *b, OperationState *result, Value *buffer); - LogicalResult verify(); - static ParseResult parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p); - - // Op-specific functionality. - Value *getBuffer() { return getOperand(); } - BufferType getBufferType() { - return getOperand()->getType().cast(); - } -}; - /// The "linalg.for" operation represents a loop nest taking 3 SSA value as /// operands that represent the lower bound, upper bound and step respectively. /// The operation defines an SSA value for its induction variable. It has one diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Linalg/IR/LinalgOps.td index 051d16e8d9ee..49b75be48818 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.td @@ -39,6 +39,65 @@ class Linalg_Op traits = []> : let parser = [{ return ::parse$cppClass(parser, result); }]; } +def BufferAllocOp : + Linalg_Op<"buffer_alloc">, + Arguments<(ins Variadic:$size)>, + Results<(outs Buffer)> { + let summary = "buffer allocation operation"; + let description = [{ + The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type, + upon which a base view can be laid out to give it indexing semantics. + "buffer_alloc" takes a single argument, the size of the buffer to allocate + (in number of elements). + + ```{.mlir} + %0 = linalg.buffer_alloc(%arg0) : !linalg.buffer + ``` + + The size argument may be omitted if it is statically known, in which case it + must be reflected in the type. + + ```{.mlir} + %0 = linalg.buffer_alloc() : !linalg.buffer<4xf32> + ``` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState *result, BufferType bufferType", [{ + result->types.push_back(bufferType); + }] + >]; + let extraClassDeclaration = [{ + BufferType getBufferType() { return getType().cast(); } + Type getElementType() { return getBufferType().getElementType(); } + }]; +} + +def BufferDeallocOp : + Linalg_Op<"buffer_dealloc">, + Arguments<(ins Buffer:$buffer)>, + Results<(outs)> { + let summary = "buffer allocation operation"; + let description = [{ + The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type. + + ```{.mlir} + linalg.buffer_dealloc %0 : !linalg.buffer + ``` + }]; + let builders = [OpBuilder< + "Builder *builder, OperationState *result, BufferType bufferType", [{ + result->types.push_back(bufferType); + }] + >]; + let extraClassDeclaration = [{ + BufferType getBufferType() { + return getOperand()->getType().cast(); + } + }]; + // Fully specified by traits. + let verifier = ?; +} + def BufferSizeOp : Linalg_Op<"buffer_size", [NoSideEffect]>, Arguments<(ins Buffer)>, diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 232379fa3feb..c23c60157206 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -37,76 +37,6 @@ using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; -////////////////////////////////////////////////////////////////////////////// -// BufferAllocOp -////////////////////////////////////////////////////////////////////////////// -void mlir::linalg::BufferAllocOp::build(Builder *b, OperationState *result, - Type type, Value *size) { - result->addOperands({size}); - result->addTypes(type); -} - -LogicalResult mlir::linalg::BufferAllocOp::verify() { - if (!size() || !size()->getType().isa()) - return emitOpError("first operand should be of type index"); - if (!VectorType::isValidElementType(getElementType()) && - !getElementType().isa()) - return emitOpError("unsupported buffer element type"); - return success(); -} - -// A BufferAllocOp prints as: -// -// ```{.mlir} -// linalg.alloc %0 : !linalg.buffer -// ``` -void mlir::linalg::BufferAllocOp::print(OpAsmPrinter *p) { - *p << getOperationName() << " " << *size() << " : " << getType(); -} - -ParseResult mlir::linalg::BufferAllocOp::parse(OpAsmParser *parser, - OperationState *result) { - OpAsmParser::OperandType sizeInfo; - BufferType bufferType; - auto indexTy = parser->getBuilder().getIndexType(); - if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType)) - return failure(); - return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) || - parser->addTypeToList(bufferType, result->types)); -} - -////////////////////////////////////////////////////////////////////////////// -// BufferDeallocOp -////////////////////////////////////////////////////////////////////////////// -void mlir::linalg::BufferDeallocOp::build(Builder *b, OperationState *result, - Value *buffer) { - result->addOperands({buffer}); -} - -LogicalResult mlir::linalg::BufferDeallocOp::verify() { - if (!getBuffer()->getType()) - return emitOpError("first operand should be of type buffer"); - return success(); -} - -// A BufferDeallocOp prints as: -// -// ```{.mlir} -// linalg.dealloc %0 : !linalg.buffer -// ``` -void mlir::linalg::BufferDeallocOp::print(OpAsmPrinter *p) { - *p << getOperationName() << " " << *getBuffer() << " : " << getBufferType(); -} - -ParseResult mlir::linalg::BufferDeallocOp::parse(OpAsmParser *parser, - OperationState *result) { - OpAsmParser::OperandType sizeInfo; - BufferType bufferType; - return failure( - parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) || - parser->resolveOperands(sizeInfo, bufferType, result->operands)); -} - //////////////////////////////////////////////////////////////////////////////// // ForOp. //////////////////////////////////////////////////////////////////////////////// @@ -605,6 +535,60 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) { // LinalgOps.td), we define an overloaded `print` function and a // parse`className` function. +static void print(OpAsmPrinter *p, BufferAllocOp op) { + *p << op.getOperationName() << " "; + if (!llvm::empty(op.size())) + *p << *op.getOperand(0); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getBufferType(); +} + +static ParseResult parseBufferAllocOp(OpAsmParser *parser, + OperationState *result) { + SmallVector sizeInfo; + BufferType bufferType; + auto indexTy = parser->getBuilder().getIndexType(); + if (parser->parseOperandList(sizeInfo) || parser->parseColonType(bufferType)) + return failure(); + if (sizeInfo.empty()) + return parser->addTypeToList(bufferType, result->types); + return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) || + parser->addTypeToList(bufferType, result->types)); +} + +static LogicalResult verify(BufferAllocOp op) { + if (!op.getBufferType().hasConstantSize()) { + if (llvm::size(op.size()) != 1 || + !op.getOperand(0)->getType().isa()) + return op.emitOpError( + "one operand of type index expected for dynamic buffer"); + } else { // op.getBufferType().hasConstantSize() + if (!llvm::empty(op.size())) + return op.emitOpError("unexpected static buffer operand"); + if (op.getBufferType().getBufferSize().getValue() <= 0) + return op.emitOpError("expected nonnegative static buffer size"); + } + if (!VectorType::isValidElementType(op.getElementType()) && + !op.getElementType().isa()) + return op.emitOpError("unsupported buffer element type"); + return success(); +} + +static void print(OpAsmPrinter *p, BufferDeallocOp op) { + *p << op.getOperationName() << " " << *op.buffer(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getBufferType(); +} + +static ParseResult parseBufferDeallocOp(OpAsmParser *parser, + OperationState *result) { + OpAsmParser::OperandType bufferInfo; + BufferType bufferType; + if (parser->parseOperand(bufferInfo) || parser->parseColonType(bufferType)) + return failure(); + return parser->resolveOperands(bufferInfo, bufferType, result->operands); +} + static void print(OpAsmPrinter *p, BufferSizeOp op) { *p << op.getOperationName() << " " << *op.getOperand(); p->printOptionalAttrDict(op.getAttrs()); diff --git a/mlir/lib/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Linalg/IR/LinalgTypes.cpp index 82be170df969..9cf9c558ba70 100644 --- a/mlir/lib/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Linalg/IR/LinalgTypes.cpp @@ -34,8 +34,7 @@ using namespace mlir::linalg; mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addTypes(); - addOperations(); + addOperations(); addOperations< #define GET_OP_LIST #include "mlir/Linalg/IR/LinalgOps.cpp.inc" @@ -119,8 +118,8 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec, // Check for '?' int64_t bufferSize = -1; if (!spec.consume_front("?")) { - unsigned parsedBufferSize; - if (!spec.consumeInteger(10, parsedBufferSize)) { + unsigned long long parsedBufferSize = 0; + if (spec.consumeInteger(10, parsedBufferSize)) { emitError(loc, "expected buffer size to be an unsigned integer"); return Type(); } diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index d43a2e622b4d..a8099aaff992 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -168,7 +168,7 @@ public: auto indexType = IndexType::get(op->getContext()); auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); - auto int64Ty = lowering.convertType(operands[0]->getType()); + auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); // Insert the `malloc` declaration if it is not already present. auto *module = op->getFunction()->getModule(); Function *mallocFunc = module->getNamedFunction("malloc"); @@ -187,14 +187,19 @@ public: llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8); else elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); - auto elementPtrType = getPtrToElementType( - allocOp.getResult()->getType().cast(), lowering); + auto bufferType = allocOp.getResult()->getType().cast(); + auto elementPtrType = getPtrToElementType(bufferType, lowering); auto bufferDescriptorType = convertLinalgType(allocOp.getResult()->getType(), lowering); // Emit IR for creating a new buffer descriptor with an underlying malloc. edsc::ScopedContext context(rewriter, op->getLoc()); - Value *size = operands[0]; + auto constantSize = bufferType.getBufferSize(); + Value *size = + constantSize + ? constant(int64Ty, IntegerAttr::get(indexType, *constantSize)) + .getValue() + : operands[0]; Value *allocSize = mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize))); Value *allocated = diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index 2133b2498c24..c927818da429 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -13,12 +13,16 @@ func @range(%arg0: index, %arg1: index, %arg2: index) { func @buffer(%arg0: index, %arg1: index) { %0 = muli %arg0, %arg0 : index %1 = linalg.buffer_alloc %0 : !linalg.buffer> + %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>> + linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>> linalg.buffer_dealloc %1 : !linalg.buffer> return } // CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) { // CHECK-NEXT: %0 = muli %arg0, %arg0 : index // CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer> +// CHECK-NEXT: %2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>> +// CHECK-NEXT: linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>> // CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer> func @view_fun(%arg0: !linalg.view>) {