From cf3959f49d193132a55bb9a626306de2ff653041 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 13 May 2019 14:59:55 -0700 Subject: [PATCH] Add a linalg.dim A linalg.dim operation is used to extract size information from !linalg.view objects passed through function call boundaries. -- PiperOrigin-RevId: 248017488 --- mlir/include/mlir/Linalg/IR/LinalgOps.td | 41 ++++++++-- mlir/include/mlir/Linalg/Passes.h | 8 +- mlir/lib/Linalg/IR/LinalgOps.cpp | 76 +++++++++++++------ .../Linalg/Transforms/LowerToLLVMDialect.cpp | 32 ++++++-- mlir/lib/Linalg/Transforms/Tiling.cpp | 3 +- mlir/test/Linalg/llvm.mlir | 7 ++ mlir/test/Linalg/roundtrip.mlir | 11 +++ 7 files changed, 136 insertions(+), 42 deletions(-) diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Linalg/IR/LinalgOps.td index 2aa1e430e275..58eb3f00401f 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.td @@ -82,10 +82,8 @@ LinalgParametricIntNativeOpTrait<"ViewRanks", ranks> class LinalgOp props> : Op { let arguments = (ins Variadic); // default variadic builder - - let parser = [{ return impl::parseLinalgLibraryOp(parser, result); }]; - - let printer = [{ impl::printLinalgLibraryOp(p, *this); }]; + let parser = [{ return parseLinalgLibraryOp(parser, result); }]; + let printer = [{ printLinalgLibraryOp(p, *this); }]; } def BufferSizeOp : @@ -93,12 +91,39 @@ def BufferSizeOp : Arguments<(ins Buffer)>, Results<(outs Index)> { - let parser = [{ - return impl::parseBufferSizeOp(parser, result); + let parser = [{ return parseBufferSizeOp(parser, result); }]; + let printer = [{ return printBufferSizeOp(p, *this); }]; +} + +def DimOp : Op, + Arguments<(ins View:$view, APIntAttr:$index)>, + Results<(outs Index)> { + let summary = "dimension index operation"; + let description = [{ + The "linalg.dim" operation takes a linalg.view and returns an + "index". It requires a single integer attribute named "index". It + returns the size of the specified dimension. For example: + + %1 = linalg.dim %0, 2 : view }]; - let printer = [{ - return impl::printBufferSizeOp(p, this->getOperation()); + let parser = [{ return parseDimOp(parser, result); }]; + let printer = [{ return printDimOp(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + + let builders = [OpBuilder< + "Builder *builder, OperationState *result, Value *view," "unsigned index", + [{ + result->addOperands(view); + result->addAttribute( + "index", builder->getIntegerAttr(builder->getIndexType(), index)); + result->types.push_back(builder->getIndexType()); + }]>]; + + let extraClassDeclaration = [{ + unsigned getIndex() { + return getAttrOfType("index").getValue().getZExtValue(); + } }]; } diff --git a/mlir/include/mlir/Linalg/Passes.h b/mlir/include/mlir/Linalg/Passes.h index 931de9095d6f..2825139dce9b 100644 --- a/mlir/include/mlir/Linalg/Passes.h +++ b/mlir/include/mlir/Linalg/Passes.h @@ -28,10 +28,12 @@ namespace mlir { class ModulePassBase; -mlir::ModulePassBase * -createLinalgTilingPass(llvm::ArrayRef tileSizes = {}); +namespace linalg { +ModulePassBase *createLinalgTilingPass(ArrayRef tileSizes = {}); -mlir::ModulePassBase *createLowerLinalgToLLVMPass(); +ModulePassBase *createLowerLinalgToLLVMPass(); + +} // namespace linalg } // namespace mlir #endif // MLIR_LINALG_PASSES_H_ diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index da102b3f394f..e6e18bb197e3 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -488,30 +488,19 @@ void mlir::linalg::ViewOp::print(OpAsmPrinter *p) { *p << "] : " << getType(); } -namespace mlir { -namespace linalg { -namespace impl { -void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op); -ParseResult parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result); -void printBufferSizeOp(OpAsmPrinter *p, Operation *op); -ParseResult parseBufferSizeOp(OpAsmParser *parser, OperationState *result); -} // namespace impl -} // namespace linalg - /// Buffer size prints as: /// /// ``` {.mlir} /// %0 = linalg.buffer_size %arg0 : !linalg.buffer /// ``` -void mlir::linalg::impl::printBufferSizeOp(OpAsmPrinter *p, Operation *op) { - assert(op->getAbstractOperation() && "unregistered operation"); - *p << cast(op).getOperationName() << " " << *op->getOperand(0); - p->printOptionalAttrDict(op->getAttrs()); - *p << " : " << op->getOperand(0)->getType(); +static void printBufferSizeOp(OpAsmPrinter *p, BufferSizeOp op) { + *p << op.getOperationName() << " " << *op.getOperand(); + p->printOptionalAttrDict(op.getAttrs()); + *p << " : " << op.getOperand()->getType(); } -ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser, - OperationState *result) { +static ParseResult parseBufferSizeOp(OpAsmParser *parser, + OperationState *result) { OpAsmParser::OperandType op; Type type; return failure(parser->parseOperand(op) || @@ -522,10 +511,44 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser, result->types)); } -#define GET_OP_CLASSES -#include "mlir/Linalg/IR/LinalgOps.cpp.inc" +static void printDimOp(OpAsmPrinter *p, DimOp op) { + *p << op.getOperationName() << " " << *op.getOperand() << ", " + << op.getIndex(); + p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"}); + *p << " : " << op.getOperand()->getType(); +} -} // namespace mlir +static ParseResult parseDimOp(OpAsmParser *parser, OperationState *result) { + OpAsmParser::OperandType operandInfo; + IntegerAttr indexAttr; + Type type; + Type indexType = parser->getBuilder().getIndexType(); + return failure(parser->parseOperand(operandInfo) || parser->parseComma() || + parser->parseAttribute(indexAttr, indexType, "index", + result->attributes) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type) || + parser->resolveOperand(operandInfo, type, result->operands) || + parser->addTypeToList(indexType, result->types)); +} + +static LogicalResult verify(linalg::DimOp op) { + // Check that we have an integer index operand. + auto indexAttr = op.getAttrOfType("index"); + if (!indexAttr) + return op.emitOpError("requires an integer attribute named 'index'"); + + uint64_t index = indexAttr.getValue().getZExtValue(); + auto type = op.getOperand()->getType(); + if (auto viewType = type.dyn_cast()) { + if (index >= viewType.getRank()) + return op.emitOpError("index is out of range"); + } else { + return op.emitOpError("requires an operand with view type"); + } + + return success(); +} // A LinalgLibraryOp prints as: // @@ -541,7 +564,7 @@ ParseResult mlir::linalg::impl::parseBufferSizeOp(OpAsmParser *parser, // ``` // // Where %0, %1 and %2 are ssa-values of type ViewType. -void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) { +static void printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) { assert(op->getAbstractOperation() && "unregistered operation"); *p << op->getName().getStringRef() << "("; interleave( @@ -553,8 +576,8 @@ void mlir::linalg::impl::printLinalgLibraryOp(OpAsmPrinter *p, Operation *op) { [&](Value *v) { *p << v->getType(); }, [&]() { *p << ", "; }); } -ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser, - OperationState *result) { +static ParseResult parseLinalgLibraryOp(OpAsmParser *parser, + OperationState *result) { SmallVector ops; SmallVector types; return failure( @@ -565,6 +588,13 @@ ParseResult mlir::linalg::impl::parseLinalgLibraryOp(OpAsmParser *parser, result->operands)); } +namespace mlir { + +#define GET_OP_CLASSES +#include "mlir/Linalg/IR/LinalgOps.cpp.inc" + +} // namespace mlir + // Ideally this should all be Tablegen'd but there is no good story for // AffineMap for now. SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 2d1f5f22da53..6c4d5c2104bc 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -154,7 +154,7 @@ static ArrayAttr makePositionAttr(FuncBuilder &builder, return builder.getArrayAttr(attrs); } -// BufferAllocOp creates a new `index` value. +// BufferAllocOp creates a new `!linalg.buffer` value. class BufferAllocOpConversion : public LLVMOpLowering { public: explicit BufferAllocOpConversion(MLIRContext *context, @@ -213,7 +213,7 @@ public: } }; -// BufferDeallocOp creates a new `index` value. +// BufferDeallocOp creates no value. class BufferDeallocOpConversion : public LLVMOpLowering { public: explicit BufferDeallocOpConversion(MLIRContext *context, @@ -268,6 +268,23 @@ public: } }; +// DimOp creates a new `index` value. +class DimOpConversion : public LLVMOpLowering { +public: + explicit DimOpConversion(MLIRContext *context, LLVMLowering &lowering_) + : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {} + + SmallVector rewrite(Operation *op, ArrayRef operands, + FuncBuilder &rewriter) const override { + auto dimOp = cast(op); + auto indexTy = lowering.convertType(rewriter.getIndexType()); + edsc::ScopedContext context(rewriter, op->getLoc()); + return {extractvalue( + indexTy, operands[0], + makePositionAttr(rewriter, {2, static_cast(dimOp.getIndex())}))}; + } +}; + namespace { // Common functionality for Linalg LoadOp and StoreOp conversion to the // LLVM IR Dialect. @@ -533,10 +550,11 @@ protected: llvm::DenseSet initAdditionalConverters() override { return ConversionListBuilder< BufferAllocOpConversion, BufferDeallocOpConversion, - BufferSizeOpConversion, DotOpConversion, LoadOpConversion, - RangeOpConversion, SliceOpConversion, StoreOpConversion, - ViewOpConversion>::build(&converterStorage, llvmDialect->getContext(), - *this); + BufferSizeOpConversion, DimOpConversion, DotOpConversion, + LoadOpConversion, RangeOpConversion, SliceOpConversion, + StoreOpConversion, ViewOpConversion>::build(&converterStorage, + llvmDialect->getContext(), + *this); } Type convertAdditionalType(Type t) override { @@ -564,7 +582,7 @@ void LowerLinalgToLLVMPass::runOnModule() { signalPassFailure(); } -ModulePassBase *mlir::createLowerLinalgToLLVMPass() { +ModulePassBase *mlir::linalg::createLowerLinalgToLLVMPass() { return new LowerLinalgToLLVMPass(); } diff --git a/mlir/lib/Linalg/Transforms/Tiling.cpp b/mlir/lib/Linalg/Transforms/Tiling.cpp index e1fa74da698b..f50076a1710a 100644 --- a/mlir/lib/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Linalg/Transforms/Tiling.cpp @@ -360,7 +360,8 @@ LinalgTilingPass::LinalgTilingPass(ArrayRef sizes) this->tileSizes.assign(sizes.begin(), sizes.end()); } -ModulePassBase *mlir::createLinalgTilingPass(ArrayRef tileSizes) { +ModulePassBase * +mlir::linalg::createLinalgTilingPass(ArrayRef tileSizes) { return new LinalgTilingPass(tileSizes); } diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index 321334267b6b..143b851c2454 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -73,3 +73,10 @@ func @dot(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg } // CHECK-LABEL: func @dot(%arg0: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg1: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %arg2: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) { // CHECK: llvm.call @linalg_dot(%arg0, %arg1, %arg2) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) -> () + +func @dim(%arg0: !linalg.view) { + %0 = linalg.dim %arg0, 1 : !linalg.view + return +} +// CHECK-LABEL: func @dim(%arg0: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) { +// CHECK: %0 = llvm.extractvalue %arg0[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index 13e360476f41..c2eed72876eb 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -52,3 +52,14 @@ func @ops(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !lina // CHECK-NEXT: linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view, !linalg.view, !linalg.view // CHECK-NEXT: linalg.dot(%arg1, %arg2, %arg3) : !linalg.view, !linalg.view, !linalg.view +func @dim(%arg0: !linalg.view) { + %0 = linalg.dim %arg0, 1 : !linalg.view + %1 = linalg.buffer_alloc %0 : !linalg.buffer + linalg.buffer_dealloc %1 : !linalg.buffer + return +} +// CHECK-LABEL: func @dim(%arg0: !linalg.view) { +// CHECK-NEXT: %0 = linalg.dim %arg0, 1 : !linalg.view +// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer +// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer +