diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index cf6f36c7098e..38e0cb6c3faf 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -290,16 +290,6 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>, res.push_back(getRange(i)); return res; } - - // This requires `SubViewOp` to be declared, in the future it should be - // folded into the builders. - static void build(Builder *builder, OperationState &result, Value *view, - ArrayRef ranges) { - result.addOperands(view); - for (auto r : ranges) - result.addOperands({r.min, r.max, r.step}); - result.types.push_back(view->getType()); - } }]; } @@ -352,7 +342,7 @@ def ViewOp : Linalg_Op<"view", [NoSideEffect]>, %1 = linalg.buffer_alloc %0 : !linalg.buffer %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range - %3 = linalg.view %1[%2, %2] : + %3 = linalg.view %1[%2, %2] : memref, stride_specification> }]; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index dffc66ed2e1e..24510f05af51 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -313,7 +313,6 @@ static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) { //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// - void mlir::linalg::SliceOp::build(Builder *b, OperationState &result, Value *base, ArrayRef indexings) { result.addOperands(base); @@ -391,22 +390,32 @@ static LogicalResult verify(SliceOp op) { //===----------------------------------------------------------------------===// // SubViewOp //===----------------------------------------------------------------------===// +static Type getSubViewResultType(MemRefType memRefType) { + auto rank = memRefType.getRank(); + SmallVector sizes(rank, -1); + int64_t offset; + SmallVector strides; + Type elementType = memRefType.getElementType(); + auto res = getStridesAndOffset(memRefType, strides, offset); + assert(succeeded(res) && "SubViewOp expected strided memref type"); + (void)res; + // Assume sizes and offset are fully dynamic for now until canonicalization + // occurs on the ranges. + // Strides don't change though. + // TODO(ntv) for canonicalization it may be better to use a (min, size, step) + // instead of a (min, max, step) abstraction. + auto stridedLayout = makeStridedLinearLayoutMap( + strides, MemRefType::getDynamicStrideOrOffset(), memRefType.getContext()); + return MemRefType::get(sizes, elementType, stridedLayout, + memRefType.getMemorySpace()); +} + void mlir::linalg::SubViewOp::build(Builder *b, OperationState &result, Value *view, ArrayRef ranges, Type resultType, ArrayRef attrs) { - // If the result type is not specified, assume sizes are fully dynamic. - // Strides don't change though. - // TODO(ntv) for canonicalization it may be better to use a (min, size, step) - // instead of a (min, max, step) abstraction. - if (!resultType) { - auto rank = ranges.size(); - SmallVector sizes(rank, -1); - auto memRefType = view->getType().cast(); - Type elementType = memRefType.getElementType(); - resultType = MemRefType::get(sizes, elementType, memRefType.getAffineMaps(), - memRefType.getMemorySpace()); - } + if (!resultType) + resultType = getSubViewResultType(view->getType().cast()); build(b, result, resultType, view, ranges); result.addAttributes(attrs); } @@ -442,7 +451,7 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { return failure( parser.resolveOperand(inputView, memRefType, result.operands) || parser.resolveOperands(ops, indexTy, result.operands) || - parser.addTypeToList(memRefType, result.types)); + parser.addTypeToList(getSubViewResultType(memRefType), result.types)); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 0e7463144b9a..c54dffe53ae9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -101,7 +101,14 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, << "loopPos: " << loopPos << "\t" << viewRanges[d]); } // TODO(ntv): opportunities for folding/CSE here rather than build new IR. - clonedViews.push_back(b.create(loc, view, viewRanges)); + SmallVector subViewOperands; + subViewOperands.reserve(viewRanges.size() * 3); + for (auto r : viewRanges) { + subViewOperands.push_back(r.min); + subViewOperands.push_back(r.max); + subViewOperands.push_back(r.step); + } + clonedViews.push_back(b.create(loc, view, subViewOperands)); } auto operands = getAssumedNonViewOperands(op); clonedViews.append(operands.begin(), operands.end()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index aca94775622a..064676f44951 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -186,13 +186,13 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, } // Construct a new subview for the tile. - SmallVector subViewOperands; - subViewOperands.reserve(rank * 3); + SmallVector subViewRangeOperands; + subViewRangeOperands.reserve(rank * 3); for (unsigned r = 0; r < rank; ++r) { if (!isTiled(map.getSubMap({r}), tileSizes)) { - subViewOperands.push_back(SubViewOp::Range{constant_index(folder, 0), - dim(view, r), - constant_index(folder, 1)}); + subViewRangeOperands.push_back( + SubViewOp::Range{constant_index(folder, 0), dim(view, r), + constant_index(folder, 1)}); continue; } @@ -201,9 +201,16 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, auto *max = applyMapToValues(b, loc, m, maxes, folder).front(); // Tiling creates a new slice at the proper index, the slice step is 1 // (i.e. the slice view does not subsample, stepping occurs in the loop). - subViewOperands.push_back( + subViewRangeOperands.push_back( SubViewOp::Range{min, max, constant_index(folder, 1)}); } + SmallVector subViewOperands; + subViewOperands.reserve(subViewRangeOperands.size() * 3); + for (auto r : subViewRangeOperands) { + subViewOperands.push_back(r.min); + subViewOperands.push_back(r.max); + subViewOperands.push_back(r.step); + } res.push_back(b.create(loc, view, subViewOperands)); } diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir index c2ef8ebc3b2e..5afb991cdc3f 100644 --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -6,30 +6,30 @@ // CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1) -func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - %0 = dim %A, 0 : memref - %1 = dim %A, 1 : memref - %2 = dim %B, 1 : memref - linalg.matmul(%A, %B, %C) : memref, memref, memref + %0 = dim %A, 0 : memref + %1 = dim %A, 1 : memref + %2 = dim %B, 1 : memref + linalg.matmul(%A, %B, %C) : memref, memref, memref %c1 = constant 1 : index loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : memref + %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : memref - %8 = linalg.subview %C[%arg5, %3, %c1, %arg6, %6, %c1] : memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %C[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : memref + return %E : memref } // CHECK-LABEL: func @f1 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) @@ -59,7 +59,7 @@ func @f2(%A: memref, %B: memref %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } @@ -95,7 +95,7 @@ func @f3(%A: memref, %B: memref %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } @@ -132,7 +132,7 @@ func @f4(%A: memref, %B: memref %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } @@ -171,7 +171,7 @@ func @f5(%A: memref, %B: memref %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } @@ -210,7 +210,7 @@ func @f6(%A: memref, %B: memref %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } @@ -250,7 +250,7 @@ func @f7(%A: memref, %B: memref %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : memref - linalg.matmul(%7, %9, %10) : memref, memref, memref + linalg.matmul(%7, %9, %10) : memref, memref, memref } } } @@ -263,7 +263,7 @@ func @f7(%A: memref, %B: memref %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : memref - linalg.matmul(%7, %9, %10) : memref, memref, memref + linalg.matmul(%7, %9, %10) : memref, memref, memref } } } @@ -308,7 +308,7 @@ func @f8(%A: memref, %B: memref %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } @@ -353,7 +353,7 @@ func @pointwise(%A: memref, %B: memref, memref, memref + }: memref, memref, memref } } return @@ -396,7 +396,7 @@ func @pointwise_no_view(%M: index, %N: index) { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors %7 = mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 - }: memref, memref, memref + }: memref, memref, memref } } return diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir index fd609ec2bdf8..458cd0b4dc25 100644 --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -20,6 +20,10 @@ // TILE-002-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) // TILE-234-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) +// TILE-2-DAG: #[[stride_99_1_layout_map:.*]] = (d0, d1)[s0] -> (d0 * 99 + s0 + d1) +// TILE-02-DAG: #[[stride_99_1_layout_map:.*]] = (d0, d1)[s0] -> (d0 * 99 + s0 + d1) +// TILE-234-DAG: #[[stride_99_1_layout_map:.*]] = (d0, d1)[s0] -> (d0 * 99 + s0 + d1) + func @matmul(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.matmul(%arg0, %arg1, %arg2) : memref, memref, memref return @@ -146,6 +150,34 @@ func @dot(%arg0: memref, %arg1: memref // TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref, memref, memref +func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) { + linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32 + return +} +// TILE-2-LABEL: func @fill_static +// TILE-2: for +// TILE-2-NOT: for +// TILE-2: linalg.subview{{.*}} : memref<127x99xf32> +// TILE-2: linalg.fill{{.*}} : memref, f32 + +// TILE-02-LABEL: func @fill_static +// TILE-02: for +// TILE-02-NOT: for +// TILE-02: linalg.subview{{.*}} : memref<127x99xf32> +// TILE-02: linalg.fill{{.*}} : memref, f32 + +// TILE-002-LABEL: func @fill_static +// TILE-002-NOT: for +// TILE-002: linalg.fill{{.*}} memref<127x99xf32>, f32 + +// TILE-234-LABEL: func @fill_static +// TILE-234: for +// TILE-234: for +// TILE-234-NOT: for +// TILE-234: linalg.subview{{.*}} : memref<127x99xf32> +// TILE-234: linalg.fill{{.*}} : memref, f32 + + func @fill(%arg0: memref, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref, f32 return