From 4e4a4c057665d169b8fe6dcdd5bb7c0b0bf8ff19 Mon Sep 17 00:00:00 2001
From: Alex Zinenko
Date: Thu, 7 Jul 2022 15:55:44 +0200
Subject: [PATCH] [mlir] Allow Tile transform op to take dynamic sizes

Extend the definition of the Tile structured transform op so that it can
accept handles to operations that produce tile sizes at runtime. This is
useful by itself and prepares for more advanced tiling strategies. Note that
the changes are relevant only to the transform dialect; the tiling
transformation itself already supports dynamic sizes.

Depends On D129216

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129217
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  51 ++++--
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |   5 +
 .../TransformOps/LinalgTransformOps.cpp       | 163 ++++++++++++++----
 mlir/lib/Dialect/Linalg/Transforms/Split.cpp  |  10 --
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |   8 +
 .../dialects/_structured_transform_ops_ext.py |  28 ++-
 .../Linalg/transform-op-scalarize.mlir        |   2 +-
 .../Dialect/Linalg/transform-op-tile.mlir     |  59 ++++++-
 mlir/test/Dialect/Linalg/transform-ops.mlir   |   2 +-
 .../Transform/selective-targeting.mlir        |   2 +-
 .../dialects/transform_structured_ext.py      |  31 +++-
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 12 files changed, 296 insertions(+), 66 deletions(-)
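To illustrate the capability this patch adds, the following condensed transform
script (adapted from the transform-op-tile.mlir test added below; the pattern
names @pdl_target and @func_call come from that test) passes a handle to a
size-producing payload op where previously only a static tile size was allowed:

    %0 = pdl_match @pdl_target in %arg1
    %1 = pdl_match @func_call in %arg1
    // Dimensions 0 and 1 are tiled by the index value produced at runtime by
    // the op matched by @func_call; dimension 2 uses a static size of 4.
    %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4]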
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 003d2eb3a544..021158f873b0 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -396,28 +396,59 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
-        FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let description = [{
-    Indicates that the given `target` op should be tiled with the options
-    provided as attributes. This transform generates a loop nest with a smaller
-    ("tiled") target operation in its body. Currently limited to LinalgOps.
+    Indicates that the given `target` op should be tiled with the given sizes.
+    This transform generates a loop nest with a smaller ("tiled") target
+    operation in its body. Currently limited to LinalgOps.
 
-    `sizes` are the tile sizes. A tile size of `0` indicates that the
-    respective dimension should not be tiled. No loop will be generated for such
-    dimensions. If all tile sizes are `0`, this transform is effectively a
-    no-op.
+    Tile sizes may be known at transformation time, in which case they are
+    expected to be provided in the `static_sizes` attribute, or not, in which
+    case the tile value must be computed by the payload IR and the handle to the
+    operation computing it must be provided through `dynamic_sizes`. When the
+    sizes are not known statically, the corresponding entry in the
+    `static_sizes` attribute must be set to `ShapedType::kDynamicSize`. Only
+    the dynamic sizes must be provided in `dynamic_sizes`, i.e., there should
+    be as many handles as `ShapedType::kDynamicSize` values in the
+    `static_sizes` attribute. A static size of `0` indicates that the dimension
+    should not be tiled. No loop will be generated for such dimensions. If all
+    tile sizes are `0`, this transform is effectively a no-op.
 
     This op returns handles to the tiled op (in the generated loop nest) and the
-    generated loops. The number of loops is the number of non-zero tile sizes.
+    generated loops. The number of loops is the number of tile sizes that are
+    statically known to be non-zero.
+
+    #### Return modes
+
+    On success, the resulting handles are associated with co-indexed lists of
+    tiled operations and loops around them.
+
+    This operation only supports Linalg ops and produces a silenceable failure
+    if the input contains any non-Linalg ops. The ops preceding it in the list
+    associated with the `target` handle will have been tiled.
+
+    This operation produces a silenceable failure if the `dynamic_sizes` handles
+    are associated with lists of payload operations of a size different than
+    that of the list associated with the `target` handle.
+
+    If the internal implementation of tiling for any of the operations fails,
+    this operation produces a definite failure.
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
+                   Variadic<PDL_Operation>:$dynamic_sizes,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
   let results = (outs PDL_Operation:$tiled_linalg_op,
                       Variadic<PDL_Operation>:$loops);
 
   let hasCustomAssemblyFormat = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns the list of tile sizes, which may be static (Attribute) or
+    /// dynamic (Value).
+    SmallVector<OpFoldResult> getMixedSizes();
+  }];
 }
 
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                     LinalgOp op, ValueRange operands,
                                     ValueRange results);
 
+/// Turns an OpFoldResult into a value, creating an index-typed constant if
+/// necessary.
+Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
+                              OpFoldResult opFoldResult);
+
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index c43e5250b41e..ab35b06157b5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -103,16 +104,10 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
 /// Apply a tiling transformation to all payload ops and store both the
 /// tiled operation as well as the created tile loops.
 static LogicalResult
-applyTilingToAll(Operation *transformOp, Value target,
-                 ArrayRef<int64_t> tileSizes,
+applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
+                 unsigned numLoops,
                  transform::TransformResults &transformResults,
-                 transform::TransformState &state,
                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
-  // Number of loops: Number of tiles sizes that are not zero.
-  size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
-  // All payload ops. These should all be LinalgOps for now.
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);
-
   SmallVector<Operation *> tiledLinalgOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
   for (unsigned int i = 0; i < numLoops; ++i)
@@ -178,8 +173,9 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
   fusionOptions.tileInterchange = extractI64Array(getTileInterchange());
 
   LogicalResult result = applyTilingToAll(
-      getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
-      state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
+      getOperation(), state.getPayloadOps(getTarget()),
+      fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
+      transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
         SimpleRewriter rewriter(getContext());
         rewriter.setInsertionPoint(linalgOp);
@@ -194,8 +190,7 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                               tileLoopNest->getLoopOps().end()};
         return tiledLinalgOp;
       });
-  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
-                        : DiagnosedSilenceableFailure::success();
+  return DiagnosedSilenceableFailure(result);
 }
 
 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
@@ -603,32 +598,141 @@ DiagnosedSilenceableFailure
 transform::TileOp::apply(TransformResults &transformResults,
                          TransformState &state) {
   LinalgTilingOptions tilingOptions;
-  SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
+  SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
 
-  if (!tileSizes.empty())
-    tilingOptions.setTileSizes(tileSizes);
-  tilingOptions.setInterchange(extractUIntArray(getInterchange()));
-  LinalgTilingPattern pattern(getContext(), tilingOptions);
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
+  dynamicSizeProducers.reserve(getDynamicSizes().size());
+  for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
+    dynamicSizeProducers.push_back(
+        state.getPayloadOps(dynamicSizeProducerHandle));
 
-  LogicalResult result = applyTilingToAll(
-      getOperation(), getTarget(), tileSizes, transformResults, state,
-      [&](LinalgOp linalgOp) {
-        SimpleRewriter rewriter(linalgOp.getContext());
-        return pattern.returningMatchAndRewrite(linalgOp, rewriter);
-      });
-  return DiagnosedSilenceableFailure(result);
+    if (dynamicSizeProducers.back().size() != targets.size()) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
+          << "expected as many dynamic size-producing operations ("
+          << dynamicSizeProducers.back().size() << ") as target ops ("
+          << targets.size() << ")";
+      diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
+      return diag;
+    }
+
+    for (Operation *op : dynamicSizeProducers.back()) {
+      if (op->getNumResults() == 1 &&
+          op->getResult(0).getType().isa<IndexType>())
+        continue;
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "expected sizes to be produced by ops "
                                    "with a single index-type result";
+      diag.attachNote(op->getLoc()) << "size producer op";
+      diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
+      return diag;
+    }
+  }
+
+  SmallVector<Operation *> tiled;
+  SmallVector<SmallVector<Operation *>, 4> loops;
+  loops.resize(getLoops().size());
+  for (auto &en : llvm::enumerate(targets)) {
+    auto linalgOp = dyn_cast<LinalgOp>(en.value());
+    if (!linalgOp) {
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "only linalg ops are supported";
+      diag.attachNote(en.value()->getLoc()) << "target op";
+      return diag;
+    }
+
+    unsigned index = en.index();
+    if (!tileSizes.empty()) {
+      tilingOptions.setTileSizeComputationFunction(
+          [&, index](OpBuilder &b, Operation *) {
+            SmallVector<Value> sizes;
+            sizes.reserve(tileSizes.size());
+            unsigned dynamicIdx = 0;
+            for (OpFoldResult ofr : getMixedSizes()) {
+              if (auto attr = ofr.dyn_cast<Attribute>()) {
+                sizes.push_back(b.create<arith::ConstantIndexOp>(
+                    getLoc(), attr.cast<IntegerAttr>().getInt()));
+              } else {
+                sizes.push_back(
+                    dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
+              }
+            }
+            return sizes;
+          });
+    }
+
+    tilingOptions.setInterchange(extractUIntArray(getInterchange()));
+    LinalgTilingPattern pattern(getContext(), tilingOptions);
+    SimpleRewriter rewriter(linalgOp.getContext());
+    FailureOr<TiledLinalgOp> tiledOp =
+        pattern.returningMatchAndRewrite(linalgOp, rewriter);
+    if (failed(tiledOp))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    tiled.push_back(tiledOp->op);
+    for (const auto &en2 : llvm::enumerate(tiledOp->loops))
+      loops[en2.index()].push_back(en2.value());
+  }
+
+  transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
+  for (const auto &en : llvm::enumerate(loops))
+    transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
+  ValueRange dynamic = getDynamicSizes();
+  SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
+  SmallVector<OpFoldResult> results;
+  results.reserve(tileSizes.size());
+  unsigned dynamicPos = 0;
+  Builder builder(getContext());
+  for (int64_t size : tileSizes) {
+    if (size == ShapedType::kDynamicSize) {
+      results.push_back(dynamic[dynamicPos++]);
+    } else {
+      results.push_back(builder.getIndexAttr(size));
+    }
+  }
+  return results;
 }
 
 ParseResult transform::TileOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
-  return parseTileLikeOp(parser, result,
-                         TileOp::getSizesAttrName(result.name).getValue());
+  OpAsmParser::UnresolvedOperand target;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
+  ArrayAttr staticSizes;
+  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
+  if (parser.parseOperand(target) ||
+      parser.resolveOperand(target, pdlOperationType, result.operands) ||
+      parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
+      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return ParseResult::failure();
+
+  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
+  size_t numExpectedLoops =
+      staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0);
+  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
+  return success();
 }
 
 void TileOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  p << getTarget();
-  p.printOptionalAttrDict((*this)->getAttrs());
+  p << ' ' << getTarget();
+  printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
+                                   getStaticSizes());
+  p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
+}
+
+void transform::TileOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getDynamicSizes(), effects);
+  producesHandle(getTiledLinalgOp(), effects);
+  producesHandle(getLoops(), effects);
+  modifiesPayload(effects);
 }
 
 //===----------------------------------------------------------------------===//
@@ -678,6 +782,7 @@ class LinalgTransformDialectExtension
           LinalgTransformDialectExtension> {
 public:
   LinalgTransformDialectExtension() {
+    declareDependentDialect<arith::ArithmeticDialect>();
     declareDependentDialect<pdl::PDLDialect>();
     declareDependentDialect<scf::SCFDialect>();
    declareDependentDialect<vector::VectorDialect>();
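For reference, the custom assembly format implemented by the parser and printer
above accepts the following forms, all taken from the tests updated or added in
this patch (the result names on the left are illustrative):

    // All-static sizes: three statically non-zero entries, hence three loop handles.
    %tiled, %loops:3 = transform.structured.tile %0 [4, 4, 4]

    // A zero size disables tiling of that dimension, so only two loop handles result.
    %t, %l:2 = transform.structured.tile %arg0 [2, 0, 3]

    // Mixed dynamic and static sizes; %first and %second are handles to
    // index-producing payload ops. The zero entry again produces no loop.
    %t2, %l2:3 = transform.structured.tile %0 [%first, 3, %second, 0]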
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
index 7d6fb66041d3..23257713e4e7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -15,16 +15,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Turns an OpFoldResult into a value, creating an index-typed constant if
-/// necessary.
-static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
-                                     OpFoldResult opFoldResult) {
-  if (opFoldResult.is<Value>())
-    return opFoldResult.get<Value>();
-  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
-  return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
-}
-
 /// Extract the slices of `operands` supplied to the given operation `op` such
 /// that they are sufficient to execute the op for the subset of its iteration
 /// space defined by `splitIterationSpace`. The subset is a part of the original
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index e8eaf38acac8..34b6714f0747 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -993,6 +993,14 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
   return tensorResults;
 }
 
+Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
+                              OpFoldResult opFoldResult) {
+  if (auto value = opFoldResult.dyn_cast<Value>())
+    return value;
+  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
+}
+
 SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
                                    LinalgOp linalgOp,
                                    ArrayRef<Value> valuesToTile,
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index beef9240d8e3..b6e078fc78b3 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -191,18 +191,40 @@ class TileOp:
   def __init__(self,
                target: Union[Operation, Value],
                *,
-               sizes: OptionalIntList = None,
+               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
+                                                     Value]], ArrayAttr]] = None,
                interchange: OptionalIntList = None,
                loc=None,
                ip=None):
     pdl_operation_type = pdl.OperationType.get()
-    sizes_attr = _get_int_array_attr(sizes)
+    i64_type = IntegerType.get_signless(64)
+
+    if sizes is None:
+      sizes = []
+
+    static_sizes = []
+    dynamic_sizes = []
+    if isinstance(sizes, ArrayAttr):
+      sizes_attr = sizes
+    else:
+      for size in sizes:
+        if isinstance(size, int):
+          static_sizes.append(IntegerAttr.get(i64_type, size))
+        elif isinstance(size, IntegerAttr):
+          static_sizes.append(size)
+        else:
+          static_sizes.append(
+              IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
+          dynamic_sizes.append(_get_op_result_or_value(size))
+      sizes_attr = ArrayAttr.get(static_sizes)
+
     num_loops = sum(
         v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
     super().__init__(
         pdl_operation_type, [pdl_operation_type] * num_loops,
         _get_op_result_or_value(target),
-        sizes=sizes_attr,
+        dynamic_sizes=dynamic_sizes,
+        static_sizes=sizes_attr,
        interchange=_get_int_array_attr(interchange) if interchange else None,
         loc=loc,
         ip=ip)
diff --git a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
index ab25777adeef..234e8ba4ef81 100644
--- a/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-scalarize.mlir
@@ -23,7 +23,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target in %arg1
-    %1, %loops = transform.structured.tile %0 {sizes = [10, 0, 0]}
+    %1, %loops = transform.structured.tile %0 [10, 0, 0]
     %2 = transform.structured.scalarize %1
   }
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index a5310944357d..61c79c090643 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -1,11 +1,11 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
 
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
   sequence %arg0 {
   ^bb0(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target in %arg1
-    %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
+    %1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
   }
 
   pdl.pattern @pdl_target : benefit(1) {
@@ -44,3 +44,58 @@ func.func @tile_linalg_matmul(
   return %0 : tensor<128x128xf32>
 }
 
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = pdl_match @func_call in %arg1
+    %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4]
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+  pdl.pattern @func_call : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "func.call"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+func.func private @get_dynamic_tile_size() -> index
+
+// CHECK-LABEL: func @tile_linalg_matmul_dynamic(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:  -> tensor<128x128xf32> {
+func.func @tile_linalg_matmul_dynamic(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
+// CHECK:   %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
+// CHECK:     %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
+// CHECK:       %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<?x4xf32>
+// CHECK:       %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x?xf32>
+// CHECK:       %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK:       %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x4xf32>, tensor<4x?xf32>)
+// CHECK-SAME:                             outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:       %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<128x128xf32>
+// CHECK:       scf.yield %[[TD]] : tensor<128x128xf32>
+// CHECK:     scf.yield %[[TD2]] : tensor<128x128xf32>
+// CHECK:   scf.yield %[[TD1]] : tensor<128x128xf32>
+  %sz = func.call @get_dynamic_tile_size() : () -> index
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+
+// CHECK: return %[[TD0]] : tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/transform-ops.mlir b/mlir/test/Dialect/Linalg/transform-ops.mlir
index ae01f3d571d3..a7d2e1e2ba82 100644
--- a/mlir/test/Dialect/Linalg/transform-ops.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops.mlir
@@ -3,7 +3,7 @@
 transform.sequence {
 ^bb1(%arg0: !pdl.operation):
   // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile
-  %0, %1:2 = transform.structured.tile %arg0 { sizes = [2, 0, 3] }
+  %0, %1:2 = transform.structured.tile %arg0 [2, 0, 3]
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/selective-targeting.mlir b/mlir/test/Dialect/Transform/selective-targeting.mlir
index eee4c2bf5445..397a80aeff5e 100644
--- a/mlir/test/Dialect/Transform/selective-targeting.mlir
+++ b/mlir/test/Dialect/Transform/selective-targeting.mlir
@@ -77,7 +77,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target_attrA in %arg1
-    transform.structured.tile %0 {sizes = [4, 4, 4]}
+    transform.structured.tile %0 [4, 4, 4]
     %1 = pdl_match @pdl_target_attrC in %arg1
     %2 = transform.get_closest_isolated_parent %1
     transform.structured.vectorize %2
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index f7838f7e2adb..cd4412f92f19 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -105,9 +105,8 @@ def testTileCompact():
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileCompact
   # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
-  # CHECK-DAG: interchange = [0, 1]
-  # CHECK-DAG: sizes = [4, 8]
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+  # CHECK: interchange = [0, 1]
 
 
 @run
@@ -122,9 +121,8 @@ def testTileAttributes():
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileAttributes
   # CHECK: transform.sequence
-  # CHECK: structured.tile
-  # CHECK-DAG: interchange = [0, 1]
-  # CHECK-DAG: sizes = [4, 8]
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+  # CHECK: interchange = [0, 1]
 
 
 @run
@@ -136,9 +134,24 @@ def testTileZero():
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileZero
   # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
-  # CHECK-DAG: interchange = [0, 1, 2, 3]
-  # CHECK-DAG: sizes = [4, 0, 2, 0]
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
+  # CHECK: interchange = [0, 1, 2, 3]
+
+
+@run
+def testTileDynamic():
+  with_pdl = transform.WithPDLPatternsOp()
+  with InsertionPoint(with_pdl.body):
+    sequence = transform.SequenceOp(with_pdl.bodyTarget)
+    with InsertionPoint(sequence.body):
+      m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
+      m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
+      structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
+      transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileDynamic
+  # CHECK: %[[FIRST:.+]] = pdl_match
+  # CHECK: %[[SECOND:.+]] = pdl_match
+  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
 
 
 @run
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e66f4fd7e452..e86b41c05bd2 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7461,6 +7461,7 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
+        ":ArithmeticDialect",
         ":IR",
         ":LinalgDialect",
         ":LinalgTransformOpsIncGen",