[mlir] Move `memref.dim` canonicalization using `InferShapedTypeOpInterface` to a separate pass.

Based on the discussion in
[this](https://llvm.discourse.group/t/remove-canonicalizer-for-memref-dim-via-shapedtypeopinterface/3641)
thread, the pattern that resolves the `memref.dim` of a value defined by
an operation implementing the `InferShapedTypeOpInterface` is moved from
the canonicalization patterns into a separate pass. This allows shape
resolution to happen when explicitly required, instead of automatically
on every canonicalization run.
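
For example, a `memref.dim` of a `linalg.init_tensor` result, as in the test
below (part of this change), is no longer folded by `-canonicalize` and is
instead resolved by the new pass:

func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
  %c2 = constant 2 : index
  %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
  return %1 : index
}

Running `mlir-opt -resolve-shaped-type-result-dims` on this resolves `%1` to
`%arg0`, so the function simply returns `%arg0`.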

Differential Revision: https://reviews.llvm.org/D104321
MaheshRavishankar 2021-06-16 22:12:16 -07:00
parent 838490de7e
commit 3ed3e438a7
15 changed files with 635 additions and 407 deletions

View File

@@ -16,6 +16,15 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
class AffineDialect;
namespace tensor {
class TensorDialect;
} // namespace tensor
namespace vector {
class VectorDialect;
} // namespace vector
namespace memref {
//===----------------------------------------------------------------------===//
@@ -26,6 +35,11 @@ namespace memref {
/// into `patterns`.
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations on values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of the shapes of the defining op's input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -34,6 +48,11 @@ void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
/// load/store ops into `patterns`.
std::unique_ptr<Pass> createFoldSubViewOpsPass();
/// Creates an operation pass that resolves `memref.dim` operations on values
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of the shapes of the defining op's input operands.
std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//

View File

@@ -23,6 +23,18 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
];
}
def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
let summary = "Resolve memref.dim of result values";
let description = [{
The pass resolves `memref.dim` of results of operations that
implement the `InferShapedTypeOpInterface` in terms of the shapes of
their operands.
}];
let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
let dependentDialects = [
"memref::MemRefDialect", "tensor::TensorDialect"
];
}
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
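
Downstream pipelines that previously relied on canonicalization for this
resolution now have to run the pass explicitly, as the updated `RUN` lines in
the tests below do. A representative invocation (illustrative; see the actual
test updates for the exact pipelines):

// RUN: mlir-opt %s -resolve-shaped-type-result-dims -canonicalize -cse | FileCheck %s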

View File

@@ -794,84 +794,12 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
return success();
}
};
/// Helper method to get the `Value` that is the shape of the given op
/// `result` at dimension `dimIndex`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
int64_t dimIndex) {
unsigned resultNumber = result.getResultNumber();
auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
Location loc = result.getOwner()->getLoc();
if (!shapedTypeOp)
return nullptr;
// The interface exposes two methods: one that returns the shape of each
// result as a single `Value`, and another that returns the shape of each
// result as a `SmallVector<Value>` with one entry per dimension. The former
// takes precedence over the latter, so first check whether the op implements
// the first interface method, and fall back to the second otherwise.
SmallVector<Value> reifiedResultShapes;
if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
if (reifiedResultShapes.size() <= resultNumber)
return nullptr;
Value resultShape = reifiedResultShapes[resultNumber];
auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
return nullptr;
return builder.create<tensor::ExtractOp>(
loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
}
SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
builder, reifiedResultShapesPerDim)))
return nullptr;
if (reifiedResultShapesPerDim.size() <= resultNumber ||
reifiedResultShapesPerDim[resultNumber].size() !=
static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
return nullptr;
OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
if (auto attr = valueOrAttr.dyn_cast<Attribute>())
return builder.createOrFold<ConstantIndexOp>(
loc, attr.cast<IntegerAttr>().getInt());
return valueOrAttr.get<Value>();
}
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
PatternRewriter &rewriter) const override {
OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
if (!dimValue)
return failure();
auto shapedTypeOp =
dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
if (!shapedTypeOp)
return failure();
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
if (!dimIndex)
return failure();
Value replacement =
getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
if (!replacement)
return failure();
rewriter.replaceOp(dimOp, replacement);
return success();
}
};
} // end anonymous namespace.
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
DimOfCastOp<tensor::CastOp>>(context);
}
// ---------------------------------------------------------------------------

View File

@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMemRefTransforms
FoldSubViewOps.cpp
ResolveShapedTypeResultDims.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
@@ -9,9 +10,11 @@ add_mlir_dialect_library(MLIRMemRefTransforms
LINK_LIBS PUBLIC
MLIRAffine
MLIRInferTypeOpInterface
MLIRMemRef
MLIRPass
MLIRStandard
MLIRTensor
MLIRTransforms
MLIRVector
)

View File

@@ -0,0 +1,127 @@
//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim of results ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass resolves `memref.dim` operations on result values in terms of
// the shapes of the operands of their defining ops, using the
// `InferShapedTypeOpInterface`.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
/// Helper method to get the `Value` that is the shape of the given op
/// `result` at dimension `dimIndex`, using the `InferShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
int64_t dimIndex) {
unsigned resultNumber = result.getResultNumber();
auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
Location loc = result.getOwner()->getLoc();
if (!shapedTypeOp)
return nullptr;
// The interface exposes two methods: one that returns the shape of each
// result as a single `Value`, and another that returns the shape of each
// result as a `SmallVector<Value>` with one entry per dimension. The former
// takes precedence over the latter, so first check whether the op implements
// the first interface method, and fall back to the second otherwise.
SmallVector<Value> reifiedResultShapes;
if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
if (reifiedResultShapes.size() <= resultNumber)
return nullptr;
Value resultShape = reifiedResultShapes[resultNumber];
auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
return nullptr;
return builder.create<tensor::ExtractOp>(
loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
}
SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
builder, reifiedResultShapesPerDim)))
return nullptr;
if (reifiedResultShapesPerDim.size() <= resultNumber ||
reifiedResultShapesPerDim[resultNumber].size() !=
static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
return nullptr;
OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
if (auto attr = valueOrAttr.dyn_cast<Attribute>())
return builder.createOrFold<ConstantIndexOp>(
loc, attr.cast<IntegerAttr>().getInt());
return valueOrAttr.get<Value>();
}
namespace {
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<memref::DimOp> {
using OpRewritePattern<memref::DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(memref::DimOp dimOp,
PatternRewriter &rewriter) const override {
OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
if (!dimValue)
return failure();
auto shapedTypeOp =
dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
if (!shapedTypeOp)
return failure();
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
if (!dimIndex)
return failure();
Value replacement =
getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
if (!replacement)
return failure();
rewriter.replaceOp(dimOp, replacement);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
struct ResolveShapedTypeResultDimsPass final
: public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
void runOnOperation() override;
};
} // namespace
void memref::populateResolveShapedTypeResultDimsPatterns(
RewritePatternSet &patterns) {
patterns.add<DimOfShapedTypeOpInterface>(patterns.getContext());
}
void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns))))
return signalPassFailure();
}
std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
return std::make_unique<ResolveShapedTypeResultDimsPass>();
}
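
As a rough sketch of what the `DimOfShapedTypeOpInterface` pattern produces on
the `reifyReturnTypeShapes` path (mirroring the new `@result_shape` test added
below; the test op and SSA names are illustrative):

%0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)
    : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = memref.dim %0#0, %c0 : tensor<?x5xf32>

// is rewritten, roughly, to:
%d0 = memref.dim %arg1, %c0 : tensor<?x5xf32>
%shape = tensor.from_elements %d0, %c5 : tensor<2xindex>
%1 = tensor.extract %shape[%c0] : tensor<2xindex>

On the `reifyReturnTypeShapesPerResultDim` path the shape is reified directly
as per-dimension `index` values, so no `tensor.from_elements`/`tensor.extract`
pair is needed.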

View File

@@ -532,205 +532,6 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {
// -----
func @init_tensor_static_dim() -> (index, index) {
%c0 = constant 0 : index
%c2 = constant 2 : index
%c6 = constant 6 : index
%0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
%1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
%2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
return %1, %2 : index, index
}
// CHECK: func @init_tensor_static_dim
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C6:.+]] = constant 6 : index
// CHECK: return %[[C6]], %[[C4]]
// -----
func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
%c2 = constant 2 : index
%0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
%1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
return %1 : index
}
// CHECK: func @init_tensor_dynamic_dim
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]]
// -----
func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
%1 = memref.dim %0, %c0 : tensor<?x?xf32>
%2 = memref.dim %0, %c1 : tensor<?x?xf32>
return %1, %2 : index, index
}
// CHECK: func @init_tensor_dynamic_dim2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]], %[[ARG1]]
// -----
func @remove_dim_result_uses
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%1 = mulf %arg3, %arg4 : f32
%2 = addf %1, %arg5 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
%3 = memref.dim %0, %c0 : tensor<?x?xf32>
%4 = memref.dim %0, %c1 : tensor<?x?xf32>
return %3, %4 : index, index
}
// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
// CHECK: func @remove_dim_result_uses
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
// CHECK: return %[[T2]], %[[T5]]
// -----
func @remove_dim_result_uses_outs
(%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = memref.dim %arg0, %c0 : tensor<?xf32>
%0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32) :
linalg.yield %arg2 : f32
} -> tensor<?x?xf32>
%2 = memref.dim %1, %c1 : tensor<?x?xf32>
return %2 : index
}
// CHECK: func @remove_dim_result_uses_outs
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG1]]
// -----
func @remove_dim_result_uses_sequence
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = memref.dim %0, %c0 : tensor<?x?xf32>
%2 = memref.dim %0, %c1 : tensor<?x?xf32>
%3 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2)>],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%0 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%4 = mulf %arg3, %arg4 : f32
%5 = addf %4, %arg5 : f32
linalg.yield %5 : f32
} -> tensor<?x?xf32>
%6 = memref.dim %3, %c0 : tensor<?x?xf32>
%7 = memref.dim %3, %c1 : tensor<?x?xf32>
return %1, %2, %6, %7 : index, index, index, index
}
// CHECK-LABEL: func @remove_dim_result_uses_sequence
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
// -----
func @keep_result_dim_uses_sequence2
(%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = memref.dim %arg0, %c0 : tensor<?xf32>
%0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3 : f32):
linalg.yield %arg2 : f32
} -> tensor<?x?xf32>
%2 = memref.dim %1, %c0 : tensor<?x?xf32>
%3 = memref.dim %1, %c1 : tensor<?x?xf32>
return %2, %3 : index, index
}
// CHECK: func @keep_result_dim_uses_sequence2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: return %[[T0]], %[[ARG1]]
// -----
#map = affine_map<(d0) -> (d0)>
func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
%arg_1: tensor<?xf32>) -> (index, index) {
%0, %1 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]
} ins(%arg_0 : tensor<?xf32>)
outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
^bb0(%in: f32, %out_0: f32, %out_1: f32):
linalg.yield %in, %in : f32, f32
} -> (tensor<?xf32>, tensor<?xf32>)
%c0 = constant 0 : index
%num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>
%num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
return %num_elem_0, %num_elem_1 : index, index
}
// CHECK: func @init_tensor_dim_of_linalg_result(
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
// CHECK: return %[[R0]], %[[R1]]
// -----
func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
%0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
%1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
@@ -740,9 +541,12 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
// CHECK: return %[[T1]]
// CHECK: %[[C2:.+]] = constant 2
// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C2]]
// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
// CHECK: return %[[INIT2]]
// -----
@@ -755,9 +559,12 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
// CHECK: return %[[T1]]
// CHECK: %[[C4:.+]] = constant 4
// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C4]]
// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
// CHECK: return %[[INIT2]]
// -----
@@ -906,54 +713,6 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
} : tensor<5x6xf32> to tensor<5x6xf32>
return %0 : tensor<5x6xf32>
}
// -----
func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
{
%c1 = constant 1 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
: tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
%1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
%2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
%3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C3]], %[[C4]], %[[D1]]
// -----
func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
{
%c1 = constant 1 : index
%c2 = constant 2 : index
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
: tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
%1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
%2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
return %1, %2 : index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @dim_reshape_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C5]], %[[D1]]
// -----
func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
%arg3 : index) -> tensor<?x?xf32> {
%c0 = constant 0 : index
@@ -1083,41 +842,6 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
// -----
func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
%arg3: f32) -> (index, index, index)
{
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%c5 = constant 5 : index
%0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
^bb0(%arg4: index, %arg5: index, %arg6: index):
linalg.yield %arg3 : f32
} : tensor<2x?x?xf32> to tensor<?x?x?xf32>
%1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
%2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
%3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
return %1, %2, %3 : index, index, index
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
// CHECK: func @dim_of_pad_op
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C12:.+]] = constant 12 : index
// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {

View File

@@ -1,4 +1,4 @@
// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s
module {
func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,

View File

@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
module {
func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,

View File

@@ -0,0 +1,278 @@
// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
func @init_tensor_static_dim() -> (index, index) {
%c0 = constant 0 : index
%c2 = constant 2 : index
%c6 = constant 6 : index
%0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
%1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
%2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
return %1, %2 : index, index
}
// CHECK: func @init_tensor_static_dim
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C6:.+]] = constant 6 : index
// CHECK: return %[[C6]], %[[C4]]
// -----
func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
%c2 = constant 2 : index
%0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
%1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
return %1 : index
}
// CHECK: func @init_tensor_dynamic_dim
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]]
// -----
func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
%1 = memref.dim %0, %c0 : tensor<?x?xf32>
%2 = memref.dim %0, %c1 : tensor<?x?xf32>
return %1, %2 : index, index
}
// CHECK: func @init_tensor_dynamic_dim2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]], %[[ARG1]]
// -----
func @remove_dim_result_uses
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%1 = mulf %arg3, %arg4 : f32
%2 = addf %1, %arg5 : f32
linalg.yield %2 : f32
} -> tensor<?x?xf32>
%3 = memref.dim %0, %c0 : tensor<?x?xf32>
%4 = memref.dim %0, %c1 : tensor<?x?xf32>
return %3, %4 : index, index
}
// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 - s0)>
// CHECK: func @remove_dim_result_uses
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
// CHECK: return %[[T2]], %[[T5]]
// -----
func @remove_dim_result_uses_outs
(%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = memref.dim %arg0, %c0 : tensor<?xf32>
%0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32) :
linalg.yield %arg2 : f32
} -> tensor<?x?xf32>
%2 = memref.dim %1, %c1 : tensor<?x?xf32>
return %2 : index
}
// CHECK: func @remove_dim_result_uses_outs
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG1]]
// -----
func @remove_dim_result_uses_sequence
(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = memref.dim %0, %c0 : tensor<?x?xf32>
%2 = memref.dim %0, %c1 : tensor<?x?xf32>
%3 = linalg.generic
{indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d0, d2)>],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%0 : tensor<?x?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%4 = mulf %arg3, %arg4 : f32
%5 = addf %4, %arg5 : f32
linalg.yield %5 : f32
} -> tensor<?x?xf32>
%6 = memref.dim %3, %c0 : tensor<?x?xf32>
%7 = memref.dim %3, %c1 : tensor<?x?xf32>
return %1, %2, %6, %7 : index, index, index, index
}
// CHECK-LABEL: func @remove_dim_result_uses_sequence
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
// -----
func @keep_result_dim_uses_sequence2
(%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = memref.dim %arg0, %c0 : tensor<?xf32>
%0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
%1 = linalg.generic
{indexing_maps = [affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3 : f32):
linalg.yield %arg2 : f32
} -> tensor<?x?xf32>
%2 = memref.dim %1, %c0 : tensor<?x?xf32>
%3 = memref.dim %1, %c1 : tensor<?x?xf32>
return %2, %3 : index, index
}
// CHECK: func @keep_result_dim_uses_sequence2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: return %[[T0]], %[[ARG1]]
// -----
#map = affine_map<(d0) -> (d0)>
func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
%arg_1: tensor<?xf32>) -> (index, index) {
%0, %1 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel"]
} ins(%arg_0 : tensor<?xf32>)
outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
^bb0(%in: f32, %out_0: f32, %out_1: f32):
linalg.yield %in, %in : f32, f32
} -> (tensor<?xf32>, tensor<?xf32>)
%c0 = constant 0 : index
%num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>
%num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
return %num_elem_0, %num_elem_1 : index, index
}
// CHECK: func @init_tensor_dim_of_linalg_result(
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
// CHECK: return %[[R0]], %[[R1]]
// -----
func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
{
%c1 = constant 1 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
: tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
%1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
%2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
%3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C3]], %[[C4]], %[[D1]]
// -----
func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
{
%c1 = constant 1 : index
%c2 = constant 2 : index
%0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
: tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
%1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
%2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
return %1, %2 : index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @dim_reshape_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C5]], %[[D1]]
// -----
func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
%arg3: f32) -> (index, index, index)
{
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%c5 = constant 5 : index
%0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
^bb0(%arg4: index, %arg5: index, %arg6: index):
linalg.yield %arg3 : f32
} : tensor<2x?x?xf32> to tensor<?x?x?xf32>
%1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
%2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
%3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
return %1, %2, %3 : index, index, index
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 4)>
// CHECK: func @dim_of_pad_op
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C12:.+]] = constant 12 : index
// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]

View File

@@ -205,16 +205,14 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
// CHECK: #[[BOUND8_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 8, -d0 + s1)>
// CHECK: #[[BOUND8_MAP_3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)>
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)>
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 4, -d0 + s1)>
// CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, 2, -d1 + s1)>
// CHECK: #[[BOUND2_MAP_3:.+]] = affine_map<(d0, d1)[s0] -> (-d0 + s0, 2, -d1 + s0)>
// CHECK: func @conv_tensors_dynamic
// CHECK-SAME: (%[[INPUT]]: tensor<?x?x?x?xf32>, %[[FILTER]]: tensor<?x?x?x?xf32>, %[[ELEM]]: tensor<?x?x?x?xf32>)
@@ -240,16 +238,20 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILL_N:.+]] = memref.dim %[[FILL]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILL_H:.+]] = memref.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILL_W:.+]] = memref.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILL_C:.+]] = memref.dim %[[FILL]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]]
// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_3]](%[[IV0]])[%[[ELEM_N]]]
// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[FILL_N]], %[[ELEM_N]]]
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[ELEM_OH]]]
// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
@@ -257,7 +259,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[ELEM_OW]]]
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
@@ -266,7 +268,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILTER_OC]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_3]](%[[IV3]], %[[IV2]])[%[[ELEM_OC]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILL_C]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf

View File

@@ -0,0 +1,88 @@
// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s
func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
%2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
%3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
%4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
%5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
// -----
func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
%2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
%3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
%4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
%5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape_per_dim(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
// -----
func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
%2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
%3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
%4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
%5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape_and_per_dim(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]

View File

@@ -82,30 +82,6 @@ func @typemismatch() -> i32 {
return %0 : i32
}
// CHECK-LABEL: func @result_shape_per_dim
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
-> (index, index, index, index, index) {
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
: (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
%1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
%2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
%3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
%4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
%5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: test_dialect_canonicalizer
func @test_dialect_canonicalizer() -> (i32) {
%0 = "test.dialect_canonicalizable"() : () -> (i32)

View File

@@ -65,6 +65,7 @@ add_mlir_library(MLIRTestDialect
MLIRReduce
MLIRStandard
MLIRStandardOpsTransforms
MLIRTensor
MLIRTransformUtils
MLIRTransforms
)

View File

@@ -13,6 +13,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
@@ -802,22 +803,75 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
return success();
}
LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
Location loc = getLoc();
shapes.reserve(operands.size());
for (Value operand : llvm::reverse(operands)) {
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(
0, operand.getType().cast<RankedTensorType>().getRank()),
[&](int64_t dim) -> Value {
return builder.createOrFold<memref::DimOp>(loc, operand, dim);
}));
shapes.push_back(builder.create<tensor::FromElementsOp>(
getLoc(), builder.getIndexType(), currShape));
}
return success();
}
LogicalResult
OpWithResultShapePerDimInterfaceOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
SmallVector<Value> operand1Shape, operand2Shape;
Location loc = getLoc();
for (auto i :
llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
shapes.reserve(getNumOperands());
for (Value operand : llvm::reverse(getOperands())) {
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(
0, operand.getType().cast<RankedTensorType>().getRank()),
[&](int64_t dim) -> Value {
return builder.createOrFold<memref::DimOp>(loc, operand, dim);
}));
shapes.emplace_back(std::move(currShape));
}
for (auto i :
llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
return success();
}
LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
OpBuilder &builder, ValueRange operands,
llvm::SmallVectorImpl<Value> &shapes) {
Location loc = getLoc();
shapes.reserve(operands.size());
for (Value operand : llvm::reverse(operands)) {
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(
0, operand.getType().cast<RankedTensorType>().getRank()),
[&](int64_t dim) -> Value {
return builder.createOrFold<memref::DimOp>(loc, operand, dim);
}));
shapes.push_back(builder.create<tensor::FromElementsOp>(
getLoc(), builder.getIndexType(), currShape));
}
return success();
}
LogicalResult
OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
Location loc = getLoc();
shapes.reserve(getNumOperands());
for (Value operand : llvm::reverse(getOperands())) {
auto currShape = llvm::to_vector<4>(llvm::map_range(
llvm::seq<int64_t>(
0, operand.getType().cast<RankedTensorType>().getRank()),
[&](int64_t dim) -> Value {
return builder.createOrFold<memref::DimOp>(loc, operand, dim);
}));
shapes.emplace_back(std::move(currShape));
}
shapes.emplace_back(std::move(operand2Shape));
shapes.emplace_back(std::move(operand1Shape));
return success();
}

View File

@@ -571,9 +571,25 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
let results = (outs AnyTensor);
}
def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]> {
["reifyReturnTypeShapes"]>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
def OpWithResultShapePerDimInterfaceOp :
TEST_Op<"op_with_result_shape_per_dim_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
def OpWithResultShapeAndPerDimInterfaceOp :
TEST_Op<"op_with_result_shape_and_per_dim_interface",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}