[mlir] Allow Tile transform op to take dynamic sizes
Extend the definition of the Tile structured transform op so that it can accept handles to operations that produce tile sizes at runtime. This is useful by itself and prepares for more advanced tiling strategies. Note that the changes are relevant only to the transform dialect; the tiling transformation itself already supports dynamic sizes.

Depends On D129216

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129217
commit 4e4a4c0576 (parent 7b69843f0b)
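In transform IR, this means a tile size can now be given either as a constant or as a handle to a payload operation that computes the size at runtime, for example (taken from the test added further down in this patch, where @func_call matches a func.call returning an index):

    %0 = pdl_match @pdl_target in %arg1
    %1 = pdl_match @func_call in %arg1
    %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4]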
@@ -396,28 +396,59 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
 def TileOp : Op<Transform_Dialect, "structured.tile",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
-     FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let description = [{
-    Indicates that the given `target` op should be tiled with the options
-    provided as attributes. This transform generates a loop nest with a smaller
-    ("tiled") target operation in its body. Currently limited to LinalgOps.
+    Indicates that the given `target` op should be tiled with the given sizes.
+    This transform generates a loop nest with a smaller ("tiled") target
+    operation in its body. Currently limited to LinalgOps.

-    `sizes` are the tile sizes. A tile size of `0` indicates that the
-    respective dimension should not be tiled. No loop will be generated for such
-    dimensions. If all tile sizes are `0`, this transform is effectively a
-    no-op.
+    Tile sizes may be known at transformation time, in which case they are
+    expected to be provided in the `static_sizes` attribute, or not, in which
+    case the tile value must be computed by the payload IR and the handle to the
+    operation computing it must be provided through `dynamic_sizes`. When the
+    sizes are not known statically, the corresponding entry in the
+    `static_sizes` attribute must be set to `ShapedType::kDynamicSize`. Only
+    the dynamic sizes must be provided in `dynamic_sizes`, i.e., there should
+    be as many handles as `ShapedType::kDynamicSize` values in the
+    `static_sizes` attribute. A static size of `0` indicates that the dimension
+    should not be tiled. No loop will be generated for such dimensions. If all
+    tile sizes are `0`, this transform is effectively a no-op.

     This op returns handles to the tiled op (in the generated loop nest) and the
-    generated loops. The number of loops is the number of non-zero tile sizes.
+    generated loops. The number of loops is the number of tile sizes that are
+    statically known to be non-zero.
+
+    #### Return modes
+
+    On success, the resulting handles are associated with co-indexed lists of
+    tiled operations and loops around them.
+
+    This operation only supports Linalg ops and produces a silenceable failure
+    if the input contains any non-Linalg ops. The ops preceding it in the list
+    associated with the `target` handle will have been tiled.
+
+    This operation produces a silenceable failure if the `dynamic_sizes` handles
+    are associated with lists of payload operations of a size different than
+    that of the list associated with the `target` handle.
+
+    If the internal implementation of tiling for any of the operations fails,
+    produces a definite failure.
   }];

   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$sizes,
+                   Variadic<PDL_Operation>:$dynamic_sizes,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$interchange);
   let results = (outs PDL_Operation:$tiled_linalg_op,
                  Variadic<PDL_Operation>:$loops);

   let hasCustomAssemblyFormat = 1;

+  let extraClassDeclaration = [{
+    /// Returns the list of tile sizes, which may be static (Attribute) or
+    /// dynamic (Value).
+    SmallVector<OpFoldResult> getMixedSizes();
+  }];
 }

 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
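As a usage sketch of the encoding described above (illustrative, not taken verbatim from the tests in this patch): a mixed size list written as [%sz, 8, 0] stores static_sizes = [ShapedType::kDynamicSize, 8, 0], passes the single handle %sz as a dynamic_sizes operand, and yields one loop handle per size that is not statically zero, here two:

    // %sz is assumed to be a handle to payload ops with a single index result.
    %tiled, %loops:2 = transform.structured.tile %target [%sz, 8, 0]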
@@ -210,6 +210,11 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                     LinalgOp op, ValueRange operands,
                                     ValueRange results);

+/// Turns an OpFoldResult into a value, creating an index-typed constant if
+/// necessary.
+Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
+                              OpFoldResult opFoldResult);
+
 /// Creates an extract_slice/subview op for a single `valueToTile` with
 /// `builder`. This new operation extracts a tile of `valueToTile`, starting
 /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"

+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -103,16 +104,10 @@ transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
 /// Apply a tiling transformation to all payload ops and store both the
 /// tiled operation as well as the created tile loops.
 static LogicalResult
-applyTilingToAll(Operation *transformOp, Value target,
-                 ArrayRef<int64_t> tileSizes,
+applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
+                 unsigned numLoops,
                  transform::TransformResults &transformResults,
-                 transform::TransformState &state,
                  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
-  // Number of loops: Number of tiles sizes that are not zero.
-  size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0);
-  // All payload ops. These should all be LinalgOps for now.
-  ArrayRef<Operation *> payloadOps = state.getPayloadOps(target);

   SmallVector<Operation *> tiledLinalgOps;
   SmallVector<SmallVector<Operation *>> loopOps(numLoops);
   for (unsigned int i = 0; i < numLoops; ++i)
@@ -178,8 +173,9 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
   fusionOptions.tileInterchange = extractI64Array(getTileInterchange());

   LogicalResult result = applyTilingToAll(
-      getOperation(), getTarget(), fusionOptions.tileSizes, transformResults,
-      state, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
+      getOperation(), state.getPayloadOps(getTarget()),
+      fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
+      transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
         LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
         SimpleRewriter rewriter(getContext());
         rewriter.setInsertionPoint(linalgOp);
@@ -194,8 +190,7 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
                            tileLoopNest->getLoopOps().end()};
         return tiledLinalgOp;
       });
-  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
-                        : DiagnosedSilenceableFailure::success();
+  return DiagnosedSilenceableFailure(result);
 }

 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
@@ -603,32 +598,141 @@ DiagnosedSilenceableFailure
 transform::TileOp::apply(TransformResults &transformResults,
                          TransformState &state) {
   LinalgTilingOptions tilingOptions;
-  SmallVector<int64_t> tileSizes = extractI64Array(getSizes());
+  SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());

-  if (!tileSizes.empty())
-    tilingOptions.setTileSizes(tileSizes);
-  tilingOptions.setInterchange(extractUIntArray(getInterchange()));
-  LinalgTilingPattern pattern(getContext(), tilingOptions);
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
+  dynamicSizeProducers.reserve(getDynamicSizes().size());
+  for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
+    dynamicSizeProducers.push_back(
+        state.getPayloadOps(dynamicSizeProducerHandle));

-  LogicalResult result = applyTilingToAll(
-      getOperation(), getTarget(), tileSizes, transformResults, state,
-      [&](LinalgOp linalgOp) {
-        SimpleRewriter rewriter(linalgOp.getContext());
-        return pattern.returningMatchAndRewrite(linalgOp, rewriter);
-      });
-  return DiagnosedSilenceableFailure(result);
+    if (dynamicSizeProducers.back().size() != targets.size()) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError()
+          << "expected as many dynamic size-producing operations ("
+          << dynamicSizeProducers.back().size() << ") as target ops ("
+          << targets.size() << ")";
+      diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
+      return diag;
+    }
+
+    for (Operation *op : dynamicSizeProducers.back()) {
+      if (op->getNumResults() == 1 &&
+          op->getResult(0).getType().isa<IndexType>())
+        continue;
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "expected sizes to be produced by ops "
+                                    "with a single index-type result";
+      diag.attachNote(op->getLoc()) << "size producer op";
+      diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
+      return diag;
+    }
+  }
+
+  SmallVector<Operation *> tiled;
+  SmallVector<SmallVector<Operation *, 4>, 4> loops;
+  loops.resize(getLoops().size());
+  for (auto &en : llvm::enumerate(targets)) {
+    auto linalgOp = dyn_cast<LinalgOp>(en.value());
+    if (!linalgOp) {
+      DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                         << "only linalg ops are supported";
+      diag.attachNote(en.value()->getLoc()) << "target op";
+      return diag;
+    }
+
+    unsigned index = en.index();
+    if (!tileSizes.empty()) {
+      tilingOptions.setTileSizeComputationFunction(
+          [&, index](OpBuilder &b, Operation *) {
+            SmallVector<Value, 4> sizes;
+            sizes.reserve(tileSizes.size());
+            unsigned dynamicIdx = 0;
+            for (OpFoldResult ofr : getMixedSizes()) {
+              if (auto attr = ofr.dyn_cast<Attribute>()) {
+                sizes.push_back(b.create<arith::ConstantIndexOp>(
+                    getLoc(), attr.cast<IntegerAttr>().getInt()));
+              } else {
+                sizes.push_back(
+                    dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
+              }
+            }
+            return sizes;
+          });
+    }
+
+    tilingOptions.setInterchange(extractUIntArray(getInterchange()));
+    LinalgTilingPattern pattern(getContext(), tilingOptions);
+    SimpleRewriter rewriter(linalgOp.getContext());
+    FailureOr<TiledLinalgOp> tiledOp =
+        pattern.returningMatchAndRewrite(linalgOp, rewriter);
+    if (failed(tiledOp))
+      return DiagnosedSilenceableFailure::definiteFailure();
+
+    tiled.push_back(tiledOp->op);
+    for (const auto &en2 : llvm::enumerate(tiledOp->loops))
+      loops[en2.index()].push_back(en2.value());
+  }
+
+  transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
+  for (const auto &en : llvm::enumerate(loops))
+    transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
+
+  return DiagnosedSilenceableFailure::success();
 }

+SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
+  ValueRange dynamic = getDynamicSizes();
+  SmallVector<int64_t> tileSizes = extractI64Array(getStaticSizes());
+  SmallVector<OpFoldResult> results;
+  results.reserve(tileSizes.size());
+  unsigned dynamicPos = 0;
+  Builder builder(getContext());
+  for (int64_t size : tileSizes) {
+    if (size == ShapedType::kDynamicSize) {
+      results.push_back(dynamic[dynamicPos++]);
+    } else {
+      results.push_back(builder.getIndexAttr(size));
+    }
+  }
+  return results;
+}
+
 ParseResult transform::TileOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
-  return parseTileLikeOp(parser, result,
-                         TileOp::getSizesAttrName(result.name).getValue());
+  OpAsmParser::UnresolvedOperand target;
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
+  ArrayAttr staticSizes;
+  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
+  if (parser.parseOperand(target) ||
+      parser.resolveOperand(target, pdlOperationType, result.operands) ||
+      parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
+      parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return ParseResult::failure();
+
+  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
+  size_t numExpectedLoops =
+      staticSizes.size() - llvm::count(extractI64Array(staticSizes), 0);
+  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
+  return success();
 }

 void TileOp::print(OpAsmPrinter &p) {
-  p << ' ';
-  p << getTarget();
-  p.printOptionalAttrDict((*this)->getAttrs());
+  p << ' ' << getTarget();
+  printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
+                                   getStaticSizes());
+  p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
 }

+void transform::TileOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getDynamicSizes(), effects);
+  producesHandle(getTiledLinalgOp(), effects);
+  producesHandle(getLoops(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
@@ -678,6 +782,7 @@ class LinalgTransformDialectExtension
                                 LinalgTransformDialectExtension> {
 public:
   LinalgTransformDialectExtension() {
+    declareDependentDialect<arith::ArithmeticDialect>();
     declareDependentDialect<pdl::PDLDialect>();
     declareDependentDialect<scf::SCFDialect>();
     declareDependentDialect<vector::VectorDialect>();
@@ -15,16 +15,6 @@
 using namespace mlir;
 using namespace mlir::linalg;

-/// Turns an OpFoldResult into a value, creating an index-typed constant if
-/// necessary.
-static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
-                                     OpFoldResult opFoldResult) {
-  if (opFoldResult.is<Value>())
-    return opFoldResult.get<Value>();
-  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
-  return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
-}
-
 /// Extract the slices of `operands` supplied to the given operation `op` such
 /// that they are sufficient to execute the op for the subset of its iteration
 /// space defined by `splitIterationSpace`. The subset is a part of the original
@@ -993,6 +993,14 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
   return tensorResults;
 }

+Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
+                              OpFoldResult opFoldResult) {
+  if (auto value = opFoldResult.dyn_cast<Value>())
+    return value;
+  auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
+  return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
+}
+
 SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
                                       LinalgOp linalgOp,
                                       ArrayRef<Value> valuesToTile,
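A hedged sketch of how the now-shared helper is meant to be used by a caller that holds a possibly-static quantity (the function name and surrounding context below are illustrative, not part of this patch):

    // Materialize an OpFoldResult as an SSA value at the rewriter's insertion
    // point, using `op`'s location for any constant that gets created.
    static Value materializeQuantity(RewriterBase &rewriter, Operation *op,
                                     OpFoldResult quantity) {
      ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
      // Returns the wrapped Value directly, or builds an arith.constant of index type.
      return linalg::materializeOpFoldResult(builder, quantity);
    }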
@@ -191,18 +191,40 @@ class TileOp:
   def __init__(self,
                target: Union[Operation, Value],
                *,
-               sizes: OptionalIntList = None,
+               sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation,
+                                                    Value]], ArrayAttr]] = None,
                interchange: OptionalIntList = None,
                loc=None,
                ip=None):
     pdl_operation_type = pdl.OperationType.get()
-    sizes_attr = _get_int_array_attr(sizes)
+    i64_type = IntegerType.get_signless(64)
+
+    if sizes is None:
+      sizes = []
+
+    static_sizes = []
+    dynamic_sizes = []
+    if isinstance(sizes, ArrayAttr):
+      sizes_attr = sizes
+    else:
+      for size in sizes:
+        if isinstance(size, int):
+          static_sizes.append(IntegerAttr.get(i64_type, size))
+        elif isinstance(size, IntegerAttr):
+          static_sizes.append(size)
+        else:
+          static_sizes.append(
+              IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
+          dynamic_sizes.append(_get_op_result_or_value(size))
+      sizes_attr = ArrayAttr.get(static_sizes)
+
     num_loops = sum(
         v if v == 0 else 1 for v in self.__extract_values(sizes_attr))
     super().__init__(
         pdl_operation_type, [pdl_operation_type] * num_loops,
         _get_op_result_or_value(target),
-        sizes=sizes_attr,
+        dynamic_sizes=dynamic_sizes,
+        static_sizes=sizes_attr,
         interchange=_get_int_array_attr(interchange) if interchange else None,
         loc=loc,
         ip=ip)
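With this, the Python builder accepts transform-IR values mixed with integers in sizes: handles become dynamic_sizes operands while the integers, plus one kDynamicSize placeholder per handle, form the static_sizes attribute. This mirrors the testTileDynamic case added below, where m1 and m2 are results of transform.PDLMatchOp:

    structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])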
@@ -23,7 +23,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target in %arg1
-    %1, %loops = transform.structured.tile %0 {sizes = [10, 0, 0]}
+    %1, %loops = transform.structured.tile %0 [10, 0, 0]
     %2 = transform.structured.scalarize %1
   }
 }
@@ -1,11 +1,11 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s

 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):
   sequence %arg0 {
   ^bb0(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target in %arg1
-    %1, %loops:3 = transform.structured.tile %0 {sizes = [4, 4, 4]}
+    %1, %loops:3 = transform.structured.tile %0 [4, 4, 4]
   }

   pdl.pattern @pdl_target : benefit(1) {
@@ -44,3 +44,58 @@ func.func @tile_linalg_matmul(
   return %0 : tensor<128x128xf32>
 }
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  sequence %arg0 {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = pdl_match @func_call in %arg1
+    %2, %loops:3 = transform.structured.tile %0 [%1, %1, 4]
+  }
+
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "linalg.matmul"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+  pdl.pattern @func_call : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = operation "func.call"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    rewrite %0 with "transform.dialect"
+  }
+}
+
+func.func private @get_dynamic_tile_size() -> index
+
+// CHECK-LABEL: func @tile_linalg_matmul_dynamic(
+// CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[TB:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[TC:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: -> tensor<128x128xf32> {
+func.func @tile_linalg_matmul_dynamic(
+  %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32> {
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[TC]]) -> (tensor<128x128xf32>) {
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<128x128xf32>) {
+// CHECK: %[[TD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[TC1]]) -> (tensor<128x128xf32>) {
+// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<128x128xf32> to tensor<?x4xf32>
+// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<128x128xf32> to tensor<4x?xf32>
+// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x4xf32>, tensor<4x?xf32>)
+// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<128x128xf32>
+// CHECK: scf.yield %[[TD]] : tensor<128x128xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<128x128xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<128x128xf32>
+  %sz = func.call @get_dynamic_tile_size() : () -> index
+  %0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2: tensor<128x128xf32>)
+    -> tensor<128x128xf32>
+
+// CHECK: return %[[TD0]] : tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
@@ -3,7 +3,7 @@
 transform.sequence {
 ^bb1(%arg0: !pdl.operation):
   // CHECK %{{.*}}, %{{.*}}:2 = transform.structured.tile
-  %0, %1:2 = transform.structured.tile %arg0 { sizes = [2, 0, 3] }
+  %0, %1:2 = transform.structured.tile %arg0 [2, 0, 3]
 }

 //===----------------------------------------------------------------------===//
@@ -77,7 +77,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target_attrA in %arg1
-    transform.structured.tile %0 {sizes = [4, 4, 4]}
+    transform.structured.tile %0 [4, 4, 4]
     %1 = pdl_match @pdl_target_attrC in %arg1
     %2 = transform.get_closest_isolated_parent %1
     transform.structured.vectorize %2
@@ -105,9 +105,8 @@ def testTileCompact():
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileCompact
   # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
-  # CHECK-DAG: interchange = [0, 1]
-  # CHECK-DAG: sizes = [4, 8]
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+  # CHECK: interchange = [0, 1]


 @run
@@ -122,9 +121,8 @@ def testTileAttributes():
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileAttributes
   # CHECK: transform.sequence
-  # CHECK: structured.tile
-  # CHECK-DAG: interchange = [0, 1]
-  # CHECK-DAG: sizes = [4, 8]
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
+  # CHECK: interchange = [0, 1]


 @run
@@ -136,9 +134,24 @@ def testTileZero():
     transform.YieldOp()
   # CHECK-LABEL: TEST: testTileZero
   # CHECK: transform.sequence
-  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile
-  # CHECK-DAG: interchange = [0, 1, 2, 3]
-  # CHECK-DAG: sizes = [4, 0, 2, 0]
+  # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
+  # CHECK: interchange = [0, 1, 2, 3]


 @run
+def testTileDynamic():
+  with_pdl = transform.WithPDLPatternsOp()
+  with InsertionPoint(with_pdl.body):
+    sequence = transform.SequenceOp(with_pdl.bodyTarget)
+    with InsertionPoint(sequence.body):
+      m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
+      m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
+      structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0])
+      transform.YieldOp()
+  # CHECK-LABEL: TEST: testTileDynamic
+  # CHECK: %[[FIRST:.+]] = pdl_match
+  # CHECK: %[[SECOND:.+]] = pdl_match
+  # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
+
+
+@run
@@ -7461,6 +7461,7 @@ cc_library(
     ],
     includes = ["include"],
     deps = [
+        ":ArithmeticDialect",
        ":IR",
        ":LinalgDialect",
        ":LinalgTransformOpsIncGen",