forked from OSchip/llvm-project
[mlir][TilingInterface] Add a method to generate scalar implementation of the op.
While The tiling interface provides a mechanism for operations to be tiled into tiled version of the op (or another op at the same level of abstraction), the `generateScalarImplementation` method added here is the "exit point" after all transformations have been done. Ops that implement this method are expected to generate IR that are directly lowerable to backend dialects like LLVM or SPIR-V dialects. Differential Revision: https://reviews.llvm.org/D130612
This commit is contained in:
parent
1e15e24a76
commit
6f03a10e4f
|
@ -141,6 +141,23 @@ private:
|
|||
TileUsingSCFForOp tilingPattern;
|
||||
};
|
||||
|
||||
/// Pattern to lower operations that implement the `TilingInterface` to
|
||||
/// loops/scalar IR using `scf.for`.
|
||||
struct LowerToLoopsUsingSCFForOp
|
||||
: public OpInterfaceRewritePattern<TilingInterface> {
|
||||
using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;
|
||||
|
||||
/// `matchAndRewrite` implementation that returns the significant transformed
|
||||
/// pieces of IR.
|
||||
FailureOr<SmallVector<scf::ForOp>>
|
||||
returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
|
||||
|
||||
LogicalResult matchAndRewrite(TilingInterface op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
return returningMatchAndRewrite(op, rewriter);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace scf
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ OpFoldResult getAsOpFoldResult(Value val);
|
|||
|
||||
/// Given an array of values, try to extract a constant Attribute from each
|
||||
/// value. If this fails, return the original value.
|
||||
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values);
|
||||
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
|
||||
|
||||
/// Convert `arrayAttr` to a vector of OpFoldResult.
|
||||
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
|
||||
|
|
|
@ -167,6 +167,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
|
|||
/*defaultImplementation=*/[{
|
||||
return failure();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/[{
|
||||
Generates the scalar implementation of the operation.
|
||||
|
||||
Given the list `ivs` that represent points in the iteration space
|
||||
(as specified by `getIterationDomain()`) returns the scalar operations
|
||||
that represent the computation at that point in the iteration space.
|
||||
This method is typically used as the "exit path", i.e. once all
|
||||
transformations are done, this method can be used to lower to scalar
|
||||
code that can then be lowered to LLVM or SPIR-V dialects.
|
||||
}],
|
||||
/*retType=*/"LogicalResult",
|
||||
/*methodName=*/"generateScalarImplementation",
|
||||
/*args=*/(ins
|
||||
"OpBuilder &":$b,
|
||||
"Location ":$loc,
|
||||
"ValueRange ":$ivs),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
return failure();
|
||||
}]
|
||||
>
|
||||
];
|
||||
}
|
||||
|
|
|
@ -13,14 +13,68 @@
|
|||
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/Interfaces/TilingInterface.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
namespace {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility methods for implementation of Tiling Interface for Linalg ops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Return the SSA values that represent the data point accessed using a given
|
||||
/// `indexingMap` for a given point in the iteration space represented by `ivs`.
|
||||
static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
|
||||
AffineMap indexingMap,
|
||||
ValueRange ivs) {
|
||||
SmallVector<Value> indices;
|
||||
indices.reserve(indexingMap.getNumResults());
|
||||
for (auto result : indexingMap.getResults()) {
|
||||
AffineMap m = AffineMap::get(indexingMap.getNumDims(),
|
||||
indexingMap.getNumSymbols(), result);
|
||||
Value v = b.create<AffineApplyOp>(loc, m, ivs);
|
||||
indices.push_back(v);
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
|
||||
/// Method to inline the payload of a `linalgOp` given the iteration space
|
||||
/// point and values for the arguments of the payload.
|
||||
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
|
||||
ValueRange ivs, ValueRange argValues) {
|
||||
Block *body = linalgOp.getBlock();
|
||||
BlockAndValueMapping map;
|
||||
map.map(body->getArguments(), argValues);
|
||||
for (auto &op : body->without_terminator()) {
|
||||
if (auto indexOp = dyn_cast<IndexOp>(&op)) {
|
||||
map.map(indexOp.getResult(), ivs[indexOp.dim()]);
|
||||
continue;
|
||||
}
|
||||
b.clone(op, map);
|
||||
}
|
||||
|
||||
Operation *terminator = body->getTerminator();
|
||||
Location loc = terminator->getLoc();
|
||||
for (auto operand : llvm::enumerate(terminator->getOperands())) {
|
||||
Value toStore = map.lookupOrDefault(operand.value());
|
||||
OpOperand *storeInto = linalgOp.getOutputOperand(operand.index());
|
||||
auto indices = getIndicesForAccess(
|
||||
b, loc, linalgOp.getTiedIndexingMap(storeInto), ivs);
|
||||
b.create<memref::StoreOp>(loc, toStore,
|
||||
linalgOp.getOutputOperand(operand.index())->get(),
|
||||
indices);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// External Model for implementing `TilingInterface` for `LinalgOp`s.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// External model implementation of TilingInterface for LinalgOps. An external
|
||||
/// model implementation is used for now till the use of `TilingInterface` is
|
||||
/// on-par with the current Linalg tiling + fusion patterns. Once it is
|
||||
|
@ -167,6 +221,38 @@ struct LinalgOpTilingInterface
|
|||
|
||||
return tiledOp[0]->getResult(resultNumber);
|
||||
}
|
||||
|
||||
LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
|
||||
Location loc,
|
||||
ValueRange ivs) const {
|
||||
auto linalgOp = cast<LinalgOp>(op);
|
||||
if (!linalgOp.hasBufferSemantics())
|
||||
return op->emitOpError("expected operation to have buffer semantics");
|
||||
|
||||
SmallVector<Value> indexedValues;
|
||||
indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
|
||||
Location linalgOpLoc = op->getLoc();
|
||||
/// Load the data corresponding to the block arguments that
|
||||
/// represent input operands.
|
||||
for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
|
||||
if (!linalgOp.payloadUsesValueFromOperand(operand)) {
|
||||
indexedValues.push_back(nullptr);
|
||||
continue;
|
||||
}
|
||||
if (linalgOp.isScalar(operand)) {
|
||||
indexedValues.push_back(operand->get());
|
||||
continue;
|
||||
}
|
||||
SmallVector<Value> indices = getIndicesForAccess(
|
||||
builder, linalgOpLoc, linalgOp.getTiedIndexingMap(operand), ivs);
|
||||
Value load =
|
||||
builder.create<memref::LoadOp>(linalgOpLoc, operand->get(), indices);
|
||||
indexedValues.push_back(load);
|
||||
}
|
||||
|
||||
/// Inline the op payload and store the result.
|
||||
return inlinePayload(builder, linalgOp, ivs, indexedValues);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -494,3 +494,41 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
|
|||
tileAndFuseResult.loops.back(), rewriter);
|
||||
return tileAndFuseResult;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LowerToLoopsUsingSCFForOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FailureOr<SmallVector<scf::ForOp>>
|
||||
scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite(
|
||||
TilingInterface op, PatternRewriter &rewriter) const {
|
||||
SmallVector<Range> domain = op.getIterationDomain(rewriter);
|
||||
|
||||
// TODO: Handle cases where the op has results if needed.
|
||||
if (op->getNumResults() > 0) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to lower to loops operations with return values");
|
||||
}
|
||||
|
||||
SmallVector<Value> ivs;
|
||||
SmallVector<scf::ForOp> loops;
|
||||
Location loc = op.getLoc();
|
||||
for (auto loopRange : domain) {
|
||||
Value offsetVal =
|
||||
getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
|
||||
Value sizeVal =
|
||||
getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
|
||||
Value strideVal =
|
||||
getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
|
||||
auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
|
||||
strideVal, ValueRange{});
|
||||
loops.push_back(loop);
|
||||
ivs.push_back(loop.getInductionVar());
|
||||
rewriter.setInsertionPoint(loop.getBody()->getTerminator());
|
||||
}
|
||||
if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
|
||||
return failure();
|
||||
}
|
||||
rewriter.eraseOp(op);
|
||||
return loops;
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ OpFoldResult getAsOpFoldResult(Value val) {
|
|||
|
||||
/// Given an array of values, try to extract a constant Attribute from each
|
||||
/// value. If this fails, return the original value.
|
||||
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayRef<Value> values) {
|
||||
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values) {
|
||||
return llvm::to_vector<4>(
|
||||
llvm::map_range(values, [](Value v) { return getAsOpFoldResult(v); }));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,161 @@
|
|||
// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s
|
||||
|
||||
func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
|
||||
%arg2 : memref<?x?xf32>) {
|
||||
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
|
||||
outs(%arg2 : memref<?x?xf32>)
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @gemm
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
|
||||
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
|
||||
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
|
||||
// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
|
||||
// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
|
||||
// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
|
||||
// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
|
||||
// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
|
||||
%arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>,
|
||||
affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0, %arg1, %arg2 : memref<200x300xi32>, memref<300xi16>, memref<200xi8>)
|
||||
outs(%arg3 : memref<300x200xi64>) {
|
||||
^bb0(%b0 : i32, %b1 : i16, %b2 : i8, %b3 : i64):
|
||||
%0 = linalg.index 0 : index
|
||||
%1 = arith.index_cast %0 : index to i16
|
||||
%2 = arith.muli %b1, %1 : i16
|
||||
%3 = linalg.index 1 : index
|
||||
%4 = arith.index_cast %3 : index to i8
|
||||
%5 = arith.muli %b2, %4 : i8
|
||||
%6 = arith.extsi %2 : i16 to i32
|
||||
%7 = arith.extsi %5 : i8 to i32
|
||||
%8 = arith.addi %6, %7 : i32
|
||||
%9 = arith.addi %8, %b0 : i32
|
||||
%10 = arith.extsi %9 : i32 to i64
|
||||
linalg.yield %10 : i64
|
||||
}
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @indexed_generic
|
||||
// CHECK-SAME: %[[ARG0:.+]]: memref<200x300xi32>
|
||||
// CHECK-SAME: %[[ARG1:.+]]: memref<300xi16>
|
||||
// CHECK-SAME: %[[ARG2:.+]]: memref<200xi8>
|
||||
// CHECK-SAME: %[[ARG3:.+]]: memref<300x200xi64>
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[C200:.+]] = arith.constant 200 : index
|
||||
// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
|
||||
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C200]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C1]]
|
||||
// CHECK-DAG: %[[B0:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV1]]]
|
||||
// CHECK-DAG: %[[B1:.+]] = memref.load %[[ARG1]][%[[IV1]]]
|
||||
// CHECK-DAG: %[[B2:.+]] = memref.load %[[ARG2]][%[[IV0]]]
|
||||
// CHECK: %[[T1:.+]] = arith.index_cast %[[IV0]]
|
||||
// CHECK: %[[T2:.+]] = arith.muli %[[B1]], %[[T1]]
|
||||
// CHECK: %[[T4:.+]] = arith.index_cast %[[IV1]]
|
||||
// CHECK: %[[T5:.+]] = arith.muli %[[B2]], %[[T4]]
|
||||
// CHECK: %[[T6:.+]] = arith.extsi %[[T2]]
|
||||
// CHECK: %[[T7:.+]] = arith.extsi %[[T5]]
|
||||
// CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]]
|
||||
// CHECK: %[[T9:.+]] = arith.addi %[[T8]], %[[B0]]
|
||||
// CHECK: %[[T10:.+]] = arith.extsi %[[T9]]
|
||||
// CHECK: memref.store %[[T10]], %[[ARG3]][%[[IV1]], %[[IV0]]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
|
||||
%arg2 : memref<?x?x?x?xf32>) {
|
||||
linalg.conv_2d_nhwc_hwcf {
|
||||
strides = dense<[1, 2]> : tensor<2xi64>,
|
||||
dilations = dense<[3, 4]> : tensor<2xi64>}
|
||||
ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
|
||||
outs(%arg2 : memref<?x?x?x?xf32>)
|
||||
return
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)>
|
||||
// CHECK: func @conv_strides_and_dilation(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
|
||||
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
|
||||
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[C:.+]] = memref.dim %[[ARG0]], %[[C3]]
|
||||
// CHECK-DAG: %[[H:.+]] = memref.dim %[[ARG1]], %[[C0]]
|
||||
// CHECK-DAG: %[[W:.+]] = memref.dim %[[ARG1]], %[[C1]]
|
||||
// CHECK-DAG: %[[F:.+]] = memref.dim %[[ARG1]], %[[C3]]
|
||||
// CHECK-DAG: %[[P:.+]] = memref.dim %[[ARG2]], %[[C1]]
|
||||
// CHECK-DAG: %[[Q:.+]] = memref.dim %[[ARG2]], %[[C2]]
|
||||
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[F]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
|
||||
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
|
||||
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]])
|
||||
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]]
|
||||
// CHECK-DAG: %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]]
|
||||
// CHECK-DAG: %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
|
||||
// CHECK: %[[T12:.+]] = arith.mulf %[[T9]], %[[T10]]
|
||||
// CHECK: %[[T13:.+]] = arith.addf %[[T11]], %[[T12]]
|
||||
// CHECK: memref.store %[[T13]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @pool_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?xf32>,
|
||||
%arg2 : memref<?x?x?x?xf32>) {
|
||||
linalg.pooling_nhwc_max {
|
||||
strides = dense<[1, 2]> : tensor<2xi64>,
|
||||
dilations = dense<[3, 4]> : tensor<2xi64>}
|
||||
ins(%arg0, %arg1 : memref<?x?x?x?xf32>, memref<?x?xf32>)
|
||||
outs(%arg2 : memref<?x?x?x?xf32>)
|
||||
return
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)>
|
||||
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)>
|
||||
// CHECK: func @pool_strides_and_dilation
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
|
||||
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
|
||||
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG0]], %[[C0]]
|
||||
// CHECK-DAG: %[[C:.+]] = memref.dim %[[ARG0]], %[[C3]]
|
||||
// CHECK-DAG: %[[H:.+]] = memref.dim %[[ARG1]], %[[C0]]
|
||||
// CHECK-DAG: %[[W:.+]] = memref.dim %[[ARG1]], %[[C1]]
|
||||
// CHECK-DAG: %[[P:.+]] = memref.dim %[[ARG2]], %[[C1]]
|
||||
// CHECK-DAG: %[[Q:.+]] = memref.dim %[[ARG2]], %[[C2]]
|
||||
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]]
|
||||
// CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]]
|
||||
// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
|
||||
// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]])
|
||||
// CHECK-DAG: %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]]
|
||||
// CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
|
||||
// CHECK: %[[T10:.+]] = arith.maxf %[[T9]], %[[T8]]
|
||||
// CHECK: memref.store %[[T10]], %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
|
|
@ -65,7 +65,7 @@ private:
|
|||
linalg::LinalgTransformationFilter filter;
|
||||
};
|
||||
|
||||
/// Pattern for testing `TileConsumerAndFUseProducersUsingSCFForOp` pattern
|
||||
/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
|
||||
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
|
||||
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
|
||||
/// application.
|
||||
|
@ -138,6 +138,12 @@ struct TestTilingInterfacePass
|
|||
"with scf.for operations"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
Option<bool> testLoweringToScalar{
|
||||
*this, "lower-to-scalar-using-scf-for",
|
||||
llvm::cl::desc("Test lowering to scalar implementation using "
|
||||
"TilingInterface with scf.for operations"),
|
||||
llvm::cl::init(false)};
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
|
@ -199,6 +205,9 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
|
|||
context, patterns, "gemm_sequence_fusion", {10});
|
||||
return;
|
||||
}
|
||||
if (testLoweringToScalar) {
|
||||
patterns.add<scf::LowerToLoopsUsingSCFForOp>(context);
|
||||
}
|
||||
}
|
||||
|
||||
void TestTilingInterfacePass::runOnOperation() {
|
||||
|
|
Loading…
Reference in New Issue