From 993e79e9bd132d344f4b79d44055c6d49f072a00 Mon Sep 17 00:00:00 2001
From: Alex Zinenko
Date: Tue, 3 Dec 2019 06:22:31 -0800
Subject: [PATCH] Fix ViewOp to have at most one offset operand

As described in the documentation, ViewOp is expected to take an optional
dynamic offset followed by a list of dynamic sizes. However, the ViewOp
parser did not check that the offset was a single value and accepted a list
of values instead.

Furthermore, several tests had been exercising the wrong syntax of a ViewOp,
passing multiple values in the dynamic offset list, which was not caught by
the parser. The trailing values could have been erroneously interpreted as
dynamic sizes. This is likely a leftover from an earlier syntax change of
ViewOp, whose previous syntax took the list of sizes before the offset.
Update the tests to use the syntax with the offset preceding the sizes (an
example of the corrected form is given after the diff).

Worse, the conversion of ViewOp to the LLVM dialect assumed the wrong order
of operands, with the offset in the trailing position, and erroneously
relied on the permissive parsing that interpreted trailing dynamic offset
values as leading dynamic sizes. Fix the lowering to use the correct order
of operands.

PiperOrigin-RevId: 283532506
---
 .../StandardToLLVM/ConvertStandardToLLVM.cpp  | 17 +++++++++++------
 mlir/lib/Dialect/StandardOps/Ops.cpp          | 10 ++++++++--
 .../StandardToLLVM/convert-to-llvmir.mlir     | 10 +++++-----
 mlir/test/Dialect/Linalg/loops.mlir           | 16 ++++++++--------
 mlir/test/Dialect/Linalg/promote.mlir         |  6 +++---
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  4 ++--
 mlir/test/IR/invalid-ops.mlir                 | 13 ++++++++++++-
 .../linalg_integration_test.mlir              |  6 +++---
 8 files changed, 52 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 0d9322088936..793997e90452 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1663,13 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern {
     // Field 3: Copy the offset in aligned pointer.
     unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
     (void)numDynamicSizes;
+    bool hasDynamicOffset = offset == MemRefType::getDynamicStrideOrOffset();
     auto sizeAndOffsetOperands = adaptor.operands();
-    assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
-           offset != MemRefType::getDynamicStrideOrOffset());
-    Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
+    assert(llvm::size(sizeAndOffsetOperands) ==
+           numDynamicSizes + (hasDynamicOffset ? 1 : 0));
+    Value *baseOffset = !hasDynamicOffset
                             ? createIndexConstant(rewriter, loc, offset)
                             // TODO(ntv): better adaptor.
-                            : sizeAndOffsetOperands.back();
+                            : sizeAndOffsetOperands.front();
     targetMemRef.setOffset(rewriter, loc, baseOffset);

     // Early exit for 0-D corner case.
@@ -1681,10 +1682,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern {
       return op->emitWarning("cannot cast to non-contiguous shape"),
             matchFailure();
     Value *stride = nullptr, *nextSize = nullptr;
+    // Drop the dynamic stride from the operand list, if present.
+    ArrayRef sizeOperands(sizeAndOffsetOperands);
+    if (hasDynamicOffset)
+      sizeOperands = sizeOperands.drop_front();
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       // Update size.
-      Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
-                            sizeAndOffsetOperands, i);
+      Value *size =
+          getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i);
       targetMemRef.setSize(rewriter, loc, i, size);
       // Update stride.
       stride = getStride(rewriter, loc, strides, nextSize, stride, i);
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 31431be50543..361135c4e297 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -2327,9 +2327,15 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
   SmallVector sizesInfo;
   auto indexType = parser.getBuilder().getIndexType();
   Type srcType, dstType;
+  llvm::SMLoc offsetLoc;
+  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
+      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
+    return failure();
+
+  if (offsetInfo.size() > 1)
+    return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand";
+
   return failure(
-      parser.parseOperand(srcInfo) ||
-      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(srcType) ||
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index 5c50ed8fb40b..a22448017a22 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -621,7 +621,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %1 = view %0[%arg0, %arg1][%arg2]
+  %1 = view %0[%arg2][%arg0, %arg1]
   : memref<2048xi8> to memref (d0 * s0 + d1 + s1)>

 // Test two dynamic sizes and static offset.
@@ -637,7 +637,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %arg0, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %2 = view %0[%arg0, %arg1][]
+  %2 = view %0[][%arg0, %arg1]
   : memref<2048xi8> to memref (d0 * s0 + d1)>

 // Test one dynamic size and dynamic offset.
@@ -653,7 +653,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %3 = view %0[%arg1][%arg2]
+  %3 = view %0[%arg2][%arg1]
   : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>

 // Test one dynamic size and static offset.
@@ -670,7 +670,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %4 = view %0[%arg0][]
+  %4 = view %0[][%arg0]
   : memref<2048xi8> to memref (d0 * 4 + d1)>

 // Test static sizes and static offset.
@@ -703,7 +703,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %[[STRIDE_1]], %[[ARG1]] : !llvm.i64
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %6 = view %0[%arg0, %arg1][%arg2]
+  %6 = view %0[%arg2][%arg0, %arg1]
   : memref<2048xi8> to memref (d0 * s0 + d1 + s1)>

   return
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 1e1a6270d7e8..7fa5594c9b51 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -16,9 +16,9 @@ func @matmul(%arg0: memref, %M: index, %N: index, %K: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %A = view %arg0[%M, %K][%c0] : memref to memref
-  %B = view %arg0[%K, %N][%c0] : memref to memref
-  %C = view %arg0[%M, %N][%c0] : memref to memref
+  %A = view %arg0[%c0][%M, %K] : memref to memref
+  %B = view %arg0[%c0][%K, %N] : memref to memref
+  %C = view %arg0[%c0][%M, %N] : memref to memref
   linalg.matmul(%A, %B, %C) : memref, memref, memref
   return
 }
@@ -42,9 +42,9 @@ func @matmul(%arg0: memref, %M: index, %N: index, %K: index) {
 func @matvec(%arg0: memref, %M: index, %N: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %2 = view %arg0[%M, %N][%c0] : memref to memref
-  %3 = view %arg0[%M][%c0] : memref to memref
-  %4 = view %arg0[%N][%c0] : memref to memref
+  %2 = view %arg0[%c0][%M, %N] : memref to memref
+  %3 = view %arg0[%c0][%M] : memref to memref
+  %4 = view %arg0[%c0][%N] : memref to memref
   linalg.matvec(%2, %3, %4) : memref, memref, memref
   return
 }
@@ -66,8 +66,8 @@ func @matvec(%arg0: memref, %M: index, %N: index) {
 func @dot(%arg0: memref, %M: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %1 = view %arg0[%M][%c0] : memref to memref
-  %2 = view %arg0[%M][%c0] : memref to memref
+  %1 = view %arg0[%c0][%M] : memref to memref
+  %2 = view %arg0[%c0][%M] : memref to memref
   %3 = view %arg0[][] : memref to memref
   linalg.dot(%1, %2, %3) : memref, memref, memref
   return
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 61d11f97bd80..51261fcc37b8 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -17,9 +17,9 @@ module {
   %c2 = constant 2 : index
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %3 = view %A[%M, %K][%c0] : memref to memref
-  %4 = view %A[%K, %N][%c0] : memref to memref
-  %5 = view %A[%M, %N][%c0] : memref to memref
+  %3 = view %A[%c0][%M, %K] : memref to memref
+  %4 = view %A[%c0][%K, %N] : memref to memref
+  %5 = view %A[%c0][%M, %N] : memref to memref
   %6 = dim %3, 0 : memref
   %7 = dim %3, 1 : memref
   %8 = dim %4, 1 : memref
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index b53e674368fa..29e04aba33ac 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -25,12 +25,12 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
   %0 = muli %arg0, %arg0 : index
   %1 = alloc (%0) : memref
   %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
-  %3 = view %1[%arg0, %arg0][%c0] : memref to memref
+  %3 = view %1[%c0][%arg0, %arg0] : memref to memref
   %4 = linalg.slice %3[%2, %2] : memref, !linalg.range, !linalg.range, memref
   %5 = linalg.slice %3[%2, %arg2] : memref, !linalg.range, index, memref
   %6 = linalg.slice %3[%arg2, %2] : memref, index, !linalg.range, memref
   %7 = linalg.slice %3[%arg2, %arg3] : memref, index, index, memref
-  %8 = view %1[%arg0, %arg0][%c0] : memref to memref, offset: ?, strides: [?, 1]>
+  %8 = view %1[%c0][%arg0, %arg0] : memref to memref, offset: ?, strides: [?, 1]>
   dealloc %1 : memref
   return
 }
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
index bd8b0cdf6771..8b9dba975026 100644
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -780,7 +780,7 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xi8>
   // expected-error@+1 {{incorrect dynamic strides}}
-  %1 = view %0[%arg0, %arg1][]
+  %1 = view %0[][%arg0, %arg1]
     : memref<2048xi8> to
       memref (d0 * 777 + d1 * 4 + d2)>
   return
@@ -799,6 +799,17 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {

 // -----

+func @multiple_offsets(%arg0: index) {
+  %0 = alloc() : memref<2048xi8>
+  // expected-error@+1 {{expects 0 or 1 offset operand}}
+  %1 = view %0[%arg0, %arg0][%arg0]
+    : memref<2048xi8> to
+      memref (d0 * 777 + d1 * 4 + d2)>
+  return
+}
+
+// -----
+
 func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2), 2>
   // expected-error@+1 {{different memory spaces}}
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index d1ee472850ab..4fce008ae82d 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -65,9 +65,9 @@ func @matmul() -> f32 {
   %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (memref)
   %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (memref)
-  %A = view %bA[%c10, %c16][] : memref to memref
-  %B = view %bB[%c16, %c10][] : memref to memref
-  %C = view %bC[%c10, %c10][] : memref to memref
+  %A = view %bA[][%c10, %c16] : memref to memref
+  %B = view %bB[][%c16, %c10] : memref to memref
+  %C = view %bC[][%c10, %c10] : memref to memref
   linalg.matmul(%A, %B, %C) : memref, memref, memref

   %res = load %C[%c6, %c7] : memref
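
An example of the corrected ViewOp syntax, as exercised by the updated tests
(a minimal sketch; %buf, %M, %N and the result layout are illustrative, not
operands taken from the patch):

  %c0 = constant 0 : index
  %v = view %buf[%c0][%M, %N]
    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>

The first bracket list carries at most one dynamic offset operand (here %c0)
and the second carries the dynamic sizes (%M and %N), which is the operand
order that the parser check and the LLVM lowering in this patch now assume.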