Fix ViewOp to have at most one offset operand

As described in the documentation, ViewOp is expected to take an optional
dynamic offset followed by a list of dynamic sizes. However, the ViewOp parser
did not check that at most one offset value is provided and accepted a list of
values instead.
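
With this change, a ViewOp with more than one value in the offset list is
rejected at parse time. For example (adapted from the new test case added
below):

  // Now rejected with "expects 0 or 1 offset operand".
  %1 = view %0[%arg0, %arg0][%arg0]
    : memref<2048xi8> to
      memref<?x?x4xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>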

Furthermore, several tests have been exercising the wrong syntax of a ViewOp,
passing multiple values in the dynamic offset list, which was not caught by the
parser. The trailing values could have been erroneously interpreted as dynamic
sizes. This is likely a leftover of an earlier resyntaxing of ViewOp, whose
previous syntax took the list of sizes before the offset. Update the tests to
use the current syntax, with the offset preceding the sizes.
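
For example (taken from the updated linalg tests below), a view with a dynamic
offset %c0 and dynamic sizes %M and %N is now written with the offset in the
first bracket and the sizes in the second:

  %2 = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>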

Worse, the conversion of ViewOp to the LLVM dialect assumed the wrong order of
operands, with the offset in the trailing position, and erroneously relied on
the permissive parsing that interpreted trailing dynamic offset values as
leading dynamic sizes. Fix the lowering to use the correct order of operands.
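
Concretely, when the offset is dynamic it is now the first index operand after
the source, followed by the dynamic sizes, so the lowering must take the offset
from the front of the size-and-offset operand list (and drop it before walking
the sizes) rather than from the back. A small sketch, with illustrative operand
names:

  %1 = view %0[%offset][%size0, %size1]
    : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
  // Index operands are {%offset, %size0, %size1}: the dynamic offset, if
  // present, comes first and the dynamic sizes follow.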

PiperOrigin-RevId: 283532506
Alex Zinenko 2019-12-03 06:22:31 -08:00 committed by A. Unique TensorFlower
parent 330d1ff00e
commit 993e79e9bd
8 changed files with 52 additions and 30 deletions

View File

@@ -1663,13 +1663,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
     // Field 3: Copy the offset in aligned pointer.
     unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
     (void)numDynamicSizes;
+    bool hasDynamicOffset = offset == MemRefType::getDynamicStrideOrOffset();
     auto sizeAndOffsetOperands = adaptor.operands();
-    assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
-           offset != MemRefType::getDynamicStrideOrOffset());
-    Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
+    assert(llvm::size(sizeAndOffsetOperands) ==
+           numDynamicSizes + (hasDynamicOffset ? 1 : 0));
+    Value *baseOffset = !hasDynamicOffset
                             ? createIndexConstant(rewriter, loc, offset)
                             // TODO(ntv): better adaptor.
-                            : sizeAndOffsetOperands.back();
+                            : sizeAndOffsetOperands.front();
     targetMemRef.setOffset(rewriter, loc, baseOffset);
     // Early exit for 0-D corner case.
@@ -1681,10 +1682,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
       return op->emitWarning("cannot cast to non-contiguous shape"),
              matchFailure();
     Value *stride = nullptr, *nextSize = nullptr;
+    // Drop the dynamic stride from the operand list, if present.
+    ArrayRef<Value *> sizeOperands(sizeAndOffsetOperands);
+    if (hasDynamicOffset)
+      sizeOperands = sizeOperands.drop_front();
     for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
       // Update size.
-      Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
-                            sizeAndOffsetOperands, i);
+      Value *size =
+          getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i);
       targetMemRef.setSize(rewriter, loc, i, size);
       // Update stride.
       stride = getStride(rewriter, loc, strides, nextSize, stride, i);

View File

@@ -2327,9 +2327,15 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
   auto indexType = parser.getBuilder().getIndexType();
   Type srcType, dstType;
+  llvm::SMLoc offsetLoc;
+  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
+      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
+    return failure();
+  if (offsetInfo.size() > 1)
+    return parser.emitError(offsetLoc) << "expects 0 or 1 offset operand";
   return failure(
-      parser.parseOperand(srcInfo) ||
-      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttrDict(result.attributes) ||
       parser.parseColonType(srcType) ||

View File

@@ -621,7 +621,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %1 = view %0[%arg0, %arg1][%arg2]
+  %1 = view %0[%arg2][%arg0, %arg1]
     : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
 // Test two dynamic sizes and static offset.
@@ -637,7 +637,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %arg0, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %2 = view %0[%arg0, %arg1][]
+  %2 = view %0[][%arg0, %arg1]
     : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>
 // Test one dynamic size and dynamic offset.
@@ -653,7 +653,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %3 = view %0[%arg1][%arg2]
+  %3 = view %0[%arg2][%arg1]
     : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
 // Test one dynamic size and static offset.
@@ -670,7 +670,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %4 = view %0[%arg0][]
+  %4 = view %0[][%arg0]
     : memref<2048xi8> to memref<?x16xf32, (d0, d1) -> (d0 * 4 + d1)>
 // Test static sizes and static offset.
@@ -703,7 +703,7 @@ func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
 // CHECK: llvm.mul %[[STRIDE_1]], %[[ARG1]] : !llvm.i64
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
-  %6 = view %0[%arg0, %arg1][%arg2]
+  %6 = view %0[%arg2][%arg0, %arg1]
     : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
   return

View File

@@ -16,9 +16,9 @@
 func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %A = view %arg0[%M, %K][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-  %B = view %arg0[%K, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-  %C = view %arg0[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %A = view %arg0[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %B = view %arg0[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %C = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
   linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
   return
 }
@@ -42,9 +42,9 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
 func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %2 = view %arg0[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-  %3 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
-  %4 = view %arg0[%N][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
+  %2 = view %arg0[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %3 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
+  %4 = view %arg0[%c0][%N] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
   linalg.matvec(%2, %3, %4) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
   return
 }
@@ -66,8 +66,8 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
 func @dot(%arg0: memref<?xi8>, %M: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
-  %1 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
-  %2 = view %arg0[%M][%c0] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
+  %1 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
+  %2 = view %arg0[%c0][%M] : memref<?xi8> to memref<?xf32, offset: ?, strides: [1]>
   %3 = view %arg0[][] : memref<?xi8> to memref<f32>
   linalg.dot(%1, %2, %3) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
   return

View File

@@ -17,9 +17,9 @@ module {
     %c2 = constant 2 : index
     %c0 = constant 0 : index
     %c1 = constant 1 : index
-    %3 = view %A[%M, %K][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
-    %4 = view %A[%K, %N][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
-    %5 = view %A[%M, %N][%c0] : memref<?xi8> to memref<?x?xf32, #map0>
+    %3 = view %A[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32, #map0>
+    %4 = view %A[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32, #map0>
+    %5 = view %A[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, #map0>
     %6 = dim %3, 0 : memref<?x?xf32, #map0>
     %7 = dim %3, 1 : memref<?x?xf32, #map0>
     %8 = dim %4, 1 : memref<?x?xf32, #map0>

View File

@@ -25,12 +25,12 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
   %0 = muli %arg0, %arg0 : index
   %1 = alloc (%0) : memref<?xi8>
   %2 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
-  %3 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
+  %3 = view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
   %4 = linalg.slice %3[%2, %2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, !linalg.range, !linalg.range, memref<?x?xf32, offset: ?, strides: [?, 1]>
   %5 = linalg.slice %3[%2, %arg2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, !linalg.range, index, memref<?xf32, offset: ?, strides: [1]>
   %6 = linalg.slice %3[%arg2, %2] : memref<?x?xf32, offset: ?, strides: [?, 1]>, index, !linalg.range, memref<?xf32, offset: ?, strides: [1]>
   %7 = linalg.slice %3[%arg2, %arg3] : memref<?x?xf32, offset: ?, strides: [?, 1]>, index, index, memref<f32>
-  %8 = view %1[%arg0, %arg0][%c0] : memref<?xi8> to memref<?x?xvector<4x4xf32>, offset: ?, strides: [?, 1]>
+  %8 = view %1[%c0][%arg0, %arg0] : memref<?xi8> to memref<?x?xvector<4x4xf32>, offset: ?, strides: [?, 1]>
   dealloc %1 : memref<?xi8>
   return
 }

View File

@@ -780,7 +780,7 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<2048xi8>
   // expected-error@+1 {{incorrect dynamic strides}}
-  %1 = view %0[%arg0, %arg1][]
+  %1 = view %0[][%arg0, %arg1]
     : memref<2048xi8> to
       memref<?x?x4xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>
   return
@@ -799,6 +799,17 @@ func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) {
 // -----
+func @multiple_offsets(%arg0: index) {
+  %0 = alloc() : memref<2048xi8>
+  // expected-error@+1 {{expects 0 or 1 offset operand}}
+  %1 = view %0[%arg0, %arg0][%arg0]
+    : memref<2048xi8> to
+      memref<?x?x4xf32, (d0, d1, d2) -> (d0 * 777 + d1 * 4 + d2)>
+  return
+}
+// -----
 func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2), 2>
   // expected-error@+1 {{different memory spaces}}

View File

@@ -65,9 +65,9 @@ func @matmul() -> f32 {
   %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (memref<?xi8>)
   %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (memref<?xi8>)
-  %A = view %bA[%c10, %c16][] : memref<?xi8> to memref<?x?xf32, #strided2D>
-  %B = view %bB[%c16, %c10][] : memref<?xi8> to memref<?x?xf32, #strided2D>
-  %C = view %bC[%c10, %c10][] : memref<?xi8> to memref<?x?xf32, #strided2D>
+  %A = view %bA[][%c10, %c16] : memref<?xi8> to memref<?x?xf32, #strided2D>
+  %B = view %bB[][%c16, %c10] : memref<?xi8> to memref<?x?xf32, #strided2D>
+  %C = view %bC[][%c10, %c10] : memref<?xi8> to memref<?x?xf32, #strided2D>
   linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
   %res = load %C[%c6, %c7] : memref<?x?xf32, #strided2D>