From c38dca7f4b697f9876b165acc9a6704f756d1173 Mon Sep 17 00:00:00 2001 From: Andy Davis Date: Wed, 6 Nov 2019 08:53:39 -0800 Subject: [PATCH] Add ViewOp to the StandardOps dialect, which casts a 1D/i8 element type memref type to an N-D memref type. Proposed in RFC: https://groups.google.com/a/tensorflow.org/forum/#!searchin/mlir/std.view%7Csort:date/mlir/-wKHANzDNTg/4K6nUAp8AAAJ Supports creating the N-D memref type with dynamic sizes and at a dynamic offset within the 1D base memref. This change contains op definition/parsing/printing and tests. Follow up changes will handle constant shape/layout map folding and llvm lowering. PiperOrigin-RevId: 278869990 --- mlir/include/mlir/Dialect/StandardOps/Ops.td | 62 +++++++++++++++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 7 +- mlir/lib/Dialect/StandardOps/Ops.cpp | 77 +++++++++++++++++++ mlir/test/IR/core-ops.mlir | 32 ++++++++ mlir/test/IR/invalid-ops.mlir | 52 +++++++++++++ 6 files changed, 228 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 6018c30397bf..4dbdcc47ff03 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1133,6 +1133,68 @@ def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> { }]; } +def ViewOp : Std_Op<"view"> { + let summary = "memref view operation"; + let description = [{ + The "view" operation converts a 1-D memref with i8 element type, + to an N-D memref with arbitrary element type. In addition, the ViewOp + supports the following arguments: + *) A dynamic size operand must be specified for each dynamic dimension + in the resulting view memref type. + *) A single dynamic offset operand can be specified which represents a + a dynamic offset within the base 1-D memref at which to create the + resulting memref view. + + // Allocate a flat 1D/i8 memref. + %0 = alloc() : memref<2048xi8> + + // ViewOp with static sizes and offset. + %1 = view %0[][] : memref<2048xi8> to memref<64x4xf32> + + // ViewOp with one dynamic size and a dynamic offset. + %2 = view %0[%size0][%offset_1024] + : memref<2048xi8> to memref (d0 * 4 + d1 + s0) + + // ViewOp creating 3D shape where two of the dim sizes are dynamic. + // *) The dynamic size for the second dimension induces a dynamic + // stride for the first dimension, which is represented by the + // symbol 's0' in the layout map of the ViewOp result memref type. + // Note that this dynamic stride will be computed from the view + // shape and dynamic sizes. + // *) The dynamic offset specified in the ViewOp is applied to the + // base 1-D memref, and is represented by the symbol 's1' in the + // layout map of the ViewOp result memref type. + %3 = view %0[%size0, %size1][%offset_1024] + : memref<2048xi8> to memref (d0 * s0 + d1 * 4 + d2 + s1) + }]; + + let arguments = (ins AnyMemRef:$source, Variadic:$operands); + let results = (outs AnyMemRef); + + let extraClassDeclaration = [{ + /// The result of a memref_shape_cast is always a memref. + MemRefType getType() { return getResult()->getType().cast(); } + + /// Returns the dynamic offset for this shape cast operation if specified. + /// Returns nullptr if no dynamic offset was specified. + Value *getDynamicOffset() { + unsigned offsetPos = 1 + getType().getNumDynamicDims(); + if (offsetPos >= getNumOperands()) + return nullptr; + return getOperand(offsetPos); + } + + /// Returns the dynamic sizes for this shape cast operation. + operand_range getDynamicSizes() { + return {operand_begin() + 1, + operand_begin() + 1 + getType().getNumDynamicDims()}; + } + }]; + // TODO(andydavis) Add canonicalizer to fold constants into shape and map. +} + + def XOrOp : IntArithmeticOp<"xor", [Commutative]> { let summary = "integer binary xor"; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index f18f9a3626dd..1c934c20069c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -529,7 +529,7 @@ void mlir::linalg::ViewOp::build(Builder *b, OperationState &result, result.addAttributes(attrs); } -static void print(OpAsmPrinter &p, ViewOp op) { +static void print(OpAsmPrinter &p, mlir::linalg::ViewOp op) { p << op.getOperationName() << " " << *op.buffer() << "["; interleaveComma(op.ranges(), p, [&](Value *v) { p << *v; }); p << "] "; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 7adf589d8887..3de6dc6b5010 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -490,14 +490,15 @@ public: class ViewOpConversion : public LLVMOpLowering { public: explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {} + : LLVMOpLowering(mlir::linalg::ViewOp::getOperationName(), context, + lowering_) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - ViewOpOperandAdaptor adaptor(operands); + mlir::linalg::ViewOpOperandAdaptor adaptor(operands); - auto viewOp = cast(op); + auto viewOp = cast(op); BaseViewConversionHelper helper(op->getLoc(), viewOp.getViewType(), rewriter, lowering); LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 22309eb8f539..e6b99035f6ec 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2340,6 +2340,83 @@ static LogicalResult verify(TruncateIOp op) { return success(); } +//===----------------------------------------------------------------------===// +// ViewOp +//===----------------------------------------------------------------------===// + +static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector offsetInfo; + SmallVector sizesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + return failure( + parser.parseOperand(srcInfo) || + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.resolveOperands(offsetInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); +} + +static void print(OpAsmPrinter &p, ViewOp op) { + p << op.getOperationName() << ' ' << *op.getOperand(0) << '['; + p.printOperands(op.getDynamicSizes()); + p << "]["; + auto *dynamicOffset = op.getDynamicOffset(); + if (dynamicOffset != nullptr) + p.printOperand(dynamicOffset); + p << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand(0)->getType() << " to " << op.getType(); +} + +static LogicalResult verify(ViewOp op) { + auto baseType = op.getOperand(0)->getType().dyn_cast(); + auto viewType = op.getResult()->getType().dyn_cast(); + + // Operand 0 type and ViewOp result type must be memref. + if (!baseType || !viewType) + return op.emitError("operand type ") << baseType << " and result type " + << viewType << " are must be memref"; + + // The base memref should be rank 1 with i8 element type. + if (baseType.getRank() != 1 || !baseType.getElementType().isInteger(8)) + return op.emitError("unsupported shape for base memref type ") << baseType; + + // The base memref should have identity layout map (or none). + if (baseType.getAffineMaps().size() > 1 || + (baseType.getAffineMaps().size() == 1 && + !baseType.getAffineMaps()[0].isIdentity())) + return op.emitError("unsupported map for base memref type ") << baseType; + + // The base memref and the view memref should be in the same memory space. + if (baseType.getMemorySpace() != viewType.getMemorySpace()) + return op.emitError("different memory spaces specified for base memref " + "type ") + << baseType << " and view memref type " << viewType; + + // Verify that the result memref type has a strided layout map. is strided + int64_t offset; + llvm::SmallVector strides; + if (failed(mlir::getStridesAndOffset(viewType, strides, offset))) + return op.emitError("result type ") << viewType << " is not strided"; + + // Verify that we have the correct number of operands for the result type. + unsigned memrefOperandCount = 1; + unsigned numDynamicDims = viewType.getNumDynamicDims(); + unsigned dynamicOffsetCount = + offset == MemRefType::getDynamicStrideOrOffset() ? 1 : 0; + if (op.getNumOperands() != + memrefOperandCount + numDynamicDims + dynamicOffsetCount) + return op.emitError("incorrect number of operands for type ") << viewType; + return success(); +} + //===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index 417068a7facf..977ec6616469 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -10,6 +10,8 @@ // CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0) // CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1) // CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0) +// CHECK-DAG: #[[VIEW_MAP0:map[0-9]+]] = (d0, d1)[s0] -> (d0 * 4 + d1 + s0) +// CHECK-DAG: #[[VIEW_MAP1:map[0-9]+]] = (d0, d1) -> (d0 * 4 + d1) // CHECK-LABEL: func @func_with_ops(%arg0: f32) { func @func_with_ops(f32) { @@ -472,6 +474,36 @@ func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref) { return } +// CHECK-LABEL: func @memref_view(%arg0 +func @memref_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<2048xi8> + // Test two dynamic sizes and dynamic offset. + // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][%arg2] : memref<2048xi8> to memref + %1 = view %0[%arg0, %arg1][%arg2] + : memref<2048xi8> to memref (d0 * 4 + d1 + s0)> + + // Test two dynamic sizes and static offset. + // CHECK: %{{.*}} = std.view %0[%arg0, %arg1][] : memref<2048xi8> to memref + %2 = view %0[%arg0, %arg1][] + : memref<2048xi8> to memref (d0 * 4 + d1)> + + // Test one dynamic size and dynamic offset. + // CHECK: %{{.*}} = std.view %0[%arg1][%arg2] : memref<2048xi8> to memref<4x?xf32, #[[VIEW_MAP0]]> + %3 = view %0[%arg1][%arg2] + : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0] -> (d0 * 4 + d1 + s0)> + + // Test one dynamic size and static offset. + // CHECK: %{{.*}} = std.view %0[%arg0][] : memref<2048xi8> to memref + %4 = view %0[%arg0][] + : memref<2048xi8> to memref (d0 * 4 + d1)> + + // Test static sizes and static offset. + // CHECK: %{{.*}} = std.view %0[][] : memref<2048xi8> to memref<64x4xf32, #[[VIEW_MAP1]]> + %5 = view %0[][] + : memref<2048xi8> to memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> + return +} + // CHECK-LABEL: func @test_dimop(%arg0 func @test_dimop(%arg0: tensor<4x4x?xf32>) { // CHECK: %0 = dim %arg0, 2 : tensor<4x4x?xf32> diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index be44a6bff396..4d45d95222d3 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -901,3 +901,55 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}} // expected-error@-1 {{expects different type than prior uses}} return } + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<2048xi8> + // expected-error@+1 {{incorrect number of operands for type}} + %1 = view %0[%arg0, %arg1][] + : memref<2048xi8> to memref (d0 * 4 + d1 + s0)> + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<2048xi8> + // expected-error@+1 {{is not strided}} + %1 = view %0[%arg0, %arg1][] + : memref<2048xi8> to memref (d0, d1, s0)> + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<2048xf32> + // expected-error@+1 {{unsupported shape for base memref}} + %1 = view %0[%arg0, %arg1][] + : memref<2048xf32> to memref (d0 * 4 + d1 + s0)> + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<2048xi8, (d0) -> (d0 floordiv 8, d0 mod 8)> + // expected-error@+1 {{unsupported map for base memref}} + %1 = view %0[%arg0, %arg1][] + : memref<2048xi8, (d0) -> (d0 floordiv 8, d0 mod 8)> to + memref (d0 * 4 + d1 + s0)> + return +} + +// ----- + +func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<2048xi8, 2> + // expected-error@+1 {{different memory spaces}} + %1 = view %0[%arg0, %arg1][] + : memref<2048xi8, 2> to + memref (d0 * 4 + d1 + s0), 1> + return +}