Fix linalg.subview behavior in (partially) static cases.

When the implementation of the strided memref [RFC](https://groups.google.com/a/tensorflow.org/forum/#!msg/mlir/MaL8m2nXuio/1scRqZa6AQAJ) landed, linalg started using this type instead of the now-retired `!linalg.view`.

Now that static and partially static cases appear, the stride information must be maintained properly. In particular, the result type of the subview op was generally incorrect.

This CL fixes the issue by computing a return type that:
1. always has dynamic sizes, which is generally the only correct way to construct a subview in the absence of data padding and/or code versioning.
2. has the same strides as the base strided memref.
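
For instance, here is a sketch in the op's custom syntax (hypothetical SSA names, modeled on the `fill_static` case in the updated tiling test below); the base memref is fully static, and the subview result keeps the static strides while its sizes and offset become dynamic:

  %sv = linalg.subview %base[%i0, %i1, %c1, %j0, %j1, %c1] : memref<127x99xf32>
  // %sv has type memref<?x?xf32, offset: ?, strides: [99, 1]>: the sizes and
  // offset are dynamic, the strides [99, 1] are inherited from the base type.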

Point 1 above can be further refined, but doing so needs further analysis and canonicalization to optimize the particular case where:
1. The base memref has static size along a given dimension.
2. The subview size can be statically derived (e.g. after canonicalization).
3. *And* the subview size is an even divisor of the base memref size.

This third constraint is well known in the case of tiled layouts that don't assume implicit padding: the boundary tile may be only partial, and its size is given by `problem_size % tile_size`.
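
Concretely, in the `fill_static` test added below, tiling the first dimension of a `memref<127x99xf32>` by 2 yields 63 full tiles (127 = 63 * 2 + 1) plus a partial boundary tile of size 127 % 2 = 1, so the subview size along that dimension cannot be made static without versioning the boundary case.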

Tests are updated as appropriate.

PiperOrigin-RevId: 274578624
Nicolas Vasilache 2019-10-14 07:58:54 -07:00 committed by A. Unique TensorFlower
parent c2285b619d
commit 5c5d83afb4
6 changed files with 97 additions and 52 deletions

@@ -290,16 +290,6 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
res.push_back(getRange(i));
return res;
}
- // This requires `SubViewOp` to be declared, in the future it should be
- // folded into the builders.
- static void build(Builder *builder, OperationState &result, Value *view,
- ArrayRef<SubViewOp::Range> ranges) {
- result.addOperands(view);
- for (auto r : ranges)
- result.addOperands({r.min, r.max, r.step});
- result.types.push_back(view->getType());
- }
}];
}
@@ -352,7 +342,7 @@ def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] :
memref<?x?xvector<4xf32>, stride_specification>
}];

@@ -313,7 +313,6 @@ static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
Value *base, ArrayRef<Value *> indexings) {
result.addOperands(base);
@@ -391,22 +390,32 @@ static LogicalResult verify(SliceOp op) {
//===----------------------------------------------------------------------===//
// SubViewOp
//===----------------------------------------------------------------------===//
+ static Type getSubViewResultType(MemRefType memRefType) {
+ auto rank = memRefType.getRank();
+ SmallVector<int64_t, 4> sizes(rank, -1);
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ Type elementType = memRefType.getElementType();
+ auto res = getStridesAndOffset(memRefType, strides, offset);
+ assert(succeeded(res) && "SubViewOp expected strided memref type");
+ (void)res;
+ // Assume sizes and offset are fully dynamic for now until canonicalization
+ // occurs on the ranges.
+ // Strides don't change though.
+ // TODO(ntv) for canonicalization it may be better to use a (min, size, step)
+ // instead of a (min, max, step) abstraction.
+ auto stridedLayout = makeStridedLinearLayoutMap(
+ strides, MemRefType::getDynamicStrideOrOffset(), memRefType.getContext());
+ return MemRefType::get(sizes, elementType, stridedLayout,
+ memRefType.getMemorySpace());
+ }
void mlir::linalg::SubViewOp::build(Builder *b, OperationState &result,
Value *view, ArrayRef<Value *> ranges,
Type resultType,
ArrayRef<NamedAttribute> attrs) {
- // If the result type is not specified, assume sizes are fully dynamic.
- // Strides don't change though.
- // TODO(ntv) for canonicalization it may be better to use a (min, size, step)
- // instead of a (min, max, step) abstraction.
- if (!resultType) {
- auto rank = ranges.size();
- SmallVector<int64_t, 4> sizes(rank, -1);
- auto memRefType = view->getType().cast<MemRefType>();
- Type elementType = memRefType.getElementType();
- resultType = MemRefType::get(sizes, elementType, memRefType.getAffineMaps(),
- memRefType.getMemorySpace());
- }
+ if (!resultType)
+ resultType = getSubViewResultType(view->getType().cast<MemRefType>());
build(b, result, resultType, view, ranges);
result.addAttributes(attrs);
}
@@ -442,7 +451,7 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
return failure(
parser.resolveOperand(inputView, memRefType, result.operands) ||
parser.resolveOperands(ops, indexTy, result.operands) ||
- parser.addTypeToList(memRefType, result.types));
+ parser.addTypeToList(getSubViewResultType(memRefType), result.types));
}
//===----------------------------------------------------------------------===//

@@ -101,7 +101,14 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
<< "loopPos: " << loopPos << "\t" << viewRanges[d]);
}
// TODO(ntv): opportunities for folding/CSE here rather than build new IR.
- clonedViews.push_back(b.create<SubViewOp>(loc, view, viewRanges));
+ SmallVector<Value *, 12> subViewOperands;
+ subViewOperands.reserve(viewRanges.size() * 3);
+ for (auto r : viewRanges) {
+ subViewOperands.push_back(r.min);
+ subViewOperands.push_back(r.max);
+ subViewOperands.push_back(r.step);
+ }
+ clonedViews.push_back(b.create<SubViewOp>(loc, view, subViewOperands));
}
auto operands = getAssumedNonViewOperands(op);
clonedViews.append(operands.begin(), operands.end());

@@ -186,13 +186,13 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
}
// Construct a new subview for the tile.
- SmallVector<SubViewOp::Range, 4> subViewOperands;
- subViewOperands.reserve(rank * 3);
+ SmallVector<SubViewOp::Range, 4> subViewRangeOperands;
+ subViewRangeOperands.reserve(rank * 3);
for (unsigned r = 0; r < rank; ++r) {
if (!isTiled(map.getSubMap({r}), tileSizes)) {
- subViewOperands.push_back(SubViewOp::Range{constant_index(folder, 0),
- dim(view, r),
- constant_index(folder, 1)});
+ subViewRangeOperands.push_back(
+ SubViewOp::Range{constant_index(folder, 0), dim(view, r),
+ constant_index(folder, 1)});
continue;
}
@@ -201,9 +201,16 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
auto *max = applyMapToValues(b, loc, m, maxes, folder).front();
// Tiling creates a new slice at the proper index, the slice step is 1
// (i.e. the slice view does not subsample, stepping occurs in the loop).
- subViewOperands.push_back(
+ subViewRangeOperands.push_back(
SubViewOp::Range{min, max, constant_index(folder, 1)});
}
+ SmallVector<Value *, 12> subViewOperands;
+ subViewOperands.reserve(subViewRangeOperands.size() * 3);
+ for (auto r : subViewRangeOperands) {
+ subViewOperands.push_back(r.min);
+ subViewOperands.push_back(r.max);
+ subViewOperands.push_back(r.step);
+ }
res.push_back(b.create<SubViewOp>(loc, view, subViewOperands));
}

@@ -6,30 +6,30 @@
// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)
- func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, offset: 0, strides: [?, ?]>, %C: memref<?x?xf32, offset: 0, strides: [?, ?]>, %D: memref<?x?xf32, offset: 0, strides: [?, ?]>, %E: memref<?x?xf32, offset: 0, strides: [?, ?]>) -> memref<?x?xf32, offset: 0, strides: [?, ?]> {
+ func @f1(%A: memref<?x?xf32, offset: 0, strides: [?, 1]>, %B: memref<?x?xf32, offset: 0, strides: [?, 1]>, %C: memref<?x?xf32, offset: 0, strides: [?, 1]>, %D: memref<?x?xf32, offset: 0, strides: [?, 1]>, %E: memref<?x?xf32, offset: 0, strides: [?, 1]>) -> memref<?x?xf32, offset: 0, strides: [?, 1]> {
%c0 = constant 0 : index
%c4 = constant 4 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
- %0 = dim %A, 0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
- %1 = dim %A, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
- %2 = dim %B, 1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ %0 = dim %A, 0 : memref<?x?xf32, offset: 0, strides: [?, 1]>
+ %1 = dim %A, 1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
+ %2 = dim %B, 1 : memref<?x?xf32, offset: 0, strides: [?, 1]>
+ linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: 0, strides: [?, 1]>, memref<?x?xf32, offset: 0, strides: [?, 1]>, memref<?x?xf32, offset: 0, strides: [?, 1]>
%c1 = constant 1 : index
loop.for %arg5 = %c0 to %0 step %c2 {
loop.for %arg6 = %c0 to %2 step %c3 {
loop.for %arg7 = %c0 to %1 step %c4 {
%3 = affine.apply #map0(%arg5)
%4 = affine.apply #map1(%arg7)
- %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
+ %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : memref<?x?xf32, offset: 0, strides: [?, 1]>
%6 = affine.apply #map2(%arg6)
- %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- %8 = linalg.subview %C[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, 1]>
+ %8 = linalg.subview %C[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, 1]>
+ linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
}
}
}
- return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
+ return %E : memref<?x?xf32, offset: 0, strides: [?, 1]>
}
// CHECK-LABEL: func @f1
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
@@ -59,7 +59,7 @@ func @f2(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%6 = affine.apply #map2(%arg6)
%7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -95,7 +95,7 @@ func @f3(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%6 = affine.apply #map2(%arg6)
%7 = linalg.subview %C[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -132,7 +132,7 @@ func @f4(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%6 = affine.apply #map2(%arg6)
%7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -171,7 +171,7 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%6 = affine.apply #map2(%arg6)
%7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -210,7 +210,7 @@ func @f6(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%6 = affine.apply #map2(%arg6)
%7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -250,7 +250,7 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%8 = affine.apply #map2(%arg6)
%9 = linalg.subview %C[%arg7, %6, %c1, %arg6, %8, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%7, %9, %10) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%7, %9, %10) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -263,7 +263,7 @@ func @f7(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%8 = affine.apply #map2(%arg6)
%9 = linalg.subview %D[%arg7, %6, %c1, %arg6, %8, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%7, %9, %10) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%7, %9, %10) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -308,7 +308,7 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?xf32, of
%6 = affine.apply #map2(%arg6)
%7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
%8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref<?x?xf32, offset: 0, strides: [?, ?]>
- linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ linalg.matmul(%5, %7, %8) : memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
@@ -353,7 +353,7 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>, %B: memref<?x?x
^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
%7 = mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
- }: memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>, memref<?x?xf32, offset: 0, strides: [?, ?]>
+ }: memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>, memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
return
@@ -396,7 +396,7 @@ func @pointwise_no_view(%M: index, %N: index) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
%7 = mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
- }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+ }: memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
}
}
return

@@ -20,6 +20,10 @@
// TILE-002-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
// TILE-234-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+ // TILE-2-DAG: #[[stride_99_1_layout_map:.*]] = (d0, d1)[s0] -> (d0 * 99 + s0 + d1)
+ // TILE-02-DAG: #[[stride_99_1_layout_map:.*]] = (d0, d1)[s0] -> (d0 * 99 + s0 + d1)
+ // TILE-234-DAG: #[[stride_99_1_layout_map:.*]] = (d0, d1)[s0] -> (d0 * 99 + s0 + d1)
func @matmul(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.matmul(%arg0, %arg1, %arg2) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
return
@@ -146,6 +150,34 @@ func @dot(%arg0: memref<?xf32, offset: ?, strides: [1]>, %arg1: memref<?xf32, of
// TILE-234: %[[sBi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>, memref<f32>
+ func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) {
+ linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32
+ return
+ }
+ // TILE-2-LABEL: func @fill_static
+ // TILE-2: for
+ // TILE-2-NOT: for
+ // TILE-2: linalg.subview{{.*}} : memref<127x99xf32>
+ // TILE-2: linalg.fill{{.*}} : memref<?x?xf32, #[[stride_99_1_layout_map]]>, f32
+ // TILE-02-LABEL: func @fill_static
+ // TILE-02: for
+ // TILE-02-NOT: for
+ // TILE-02: linalg.subview{{.*}} : memref<127x99xf32>
+ // TILE-02: linalg.fill{{.*}} : memref<?x?xf32, #[[stride_99_1_layout_map]]>, f32
+ // TILE-002-LABEL: func @fill_static
+ // TILE-002-NOT: for
+ // TILE-002: linalg.fill{{.*}} memref<127x99xf32>, f32
+ // TILE-234-LABEL: func @fill_static
+ // TILE-234: for
+ // TILE-234: for
+ // TILE-234-NOT: for
+ // TILE-234: linalg.subview{{.*}} : memref<127x99xf32>
+ // TILE-234: linalg.fill{{.*}} : memref<?x?xf32, #[[stride_99_1_layout_map]]>, f32
func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : memref<?x?xf32, offset: ?, strides: [?, 1]>, f32
return