forked from OSchip/llvm-project
[mlir] Move `memref.dim` canonicalization using `InferShapedTypeOpInterface` to a separate pass.
Based on discussion in [this](https://llvm.discourse.group/t/remove-canonicalizer-for-memref-dim-via-shapedtypeopinterface/3641) thread, the pattern that resolves a `memref.dim` of a value produced by an operation implementing the `InferShapedTypeOpInterface` is moved out of canonicalization and into a separate pass. This lets shape resolution happen when explicitly requested, rather than automatically on every canonicalization. Differential Revision: https://reviews.llvm.org/D104321
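A minimal sketch of opting in to the new behavior from C++; the pass entry point comes from this patch, while the function name `buildShapeResolutionPipeline` and the pipeline nesting are illustrative assumptions:

    // Schedule the new pass explicitly instead of relying on canonicalization;
    // roughly what `mlir-opt -resolve-shaped-type-result-dims` does.
    #include "mlir/Dialect/MemRef/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"

    static void buildShapeResolutionPipeline(mlir::PassManager &pm) {
      // Resolves memref.dim of op results in terms of operand shapes.
      pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
    }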
This commit is contained in:
parent 838490de7e
commit 3ed3e438a7
@@ -16,6 +16,15 @@
#include "mlir/Pass/Pass.h"

namespace mlir {

class AffineDialect;
namespace tensor {
class TensorDialect;
} // namespace tensor
namespace vector {
class VectorDialect;
} // namespace vector

namespace memref {

//===----------------------------------------------------------------------===//
@@ -26,6 +35,11 @@
/// into `patterns`.
void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);

/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of the shapes of their input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
@@ -34,6 +48,11 @@ void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
/// load/store ops into `patterns`.
std::unique_ptr<Pass> createFoldSubViewOpsPass();

/// Creates an operation pass to resolve `memref.dim` operations with values
/// that are defined by operations that implement the
/// `InferShapedTypeOpInterface`, in terms of the shapes of their input operands.
std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
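Downstream passes can also pick up just the rewrite patterns instead of scheduling the whole pass; a hedged sketch (the enclosing pass boilerplate and this greedy-driver overload are assumed):

    // Inside some pass's runOnOperation(): mix the dim-resolution patterns
    // into a local rewrite set and apply them greedily.
    RewritePatternSet patterns(&getContext());
    memref::populateResolveShapedTypeResultDimsPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();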
@@ -23,6 +23,18 @@ def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
  ];
}

def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
  let summary = "Resolve memref.dim of result values";
  let description = [{
    The pass resolves memref.dim of the results of operations that
    implement the `InferShapedTypeOpInterface` in terms of the shapes of
    their operands.
  }];
  let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
  let dependentDialects = [
    "memref::MemRefDialect", "tensor::TensorDialect"
  ];
}

#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
@@ -794,84 +794,12 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
    return success();
  }
};

/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value` and another that returns the shape as a list of
  // `SmallVector<Value>`. The former takes precedence over the latter. So first
  // check if the op implements the first interface method or the second, and
  // get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp<BufferCastOp>,
-             DimOfCastOp<tensor::CastOp>, DimOfShapedTypeOpInterface>(context);
+             DimOfCastOp<tensor::CastOp>>(context);
}

// ---------------------------------------------------------------------------
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMemRefTransforms
  FoldSubViewOps.cpp
+  ResolveShapedTypeResultDims.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
@@ -9,9 +10,11 @@ add_mlir_dialect_library(MLIRMemRefTransforms

  LINK_LIBS PUBLIC
  MLIRAffine
+  MLIRInferTypeOpInterface
  MLIRMemRef
  MLIRPass
  MLIRStandard
+  MLIRTensor
  MLIRTransforms
  MLIRVector
)
@@ -0,0 +1,127 @@
//===- ResolveShapedTypeResultDims.cpp - Resolve memref.dim ops of result values -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass resolves `memref.dim` operations of result values in terms of
// shapes of their operands using the `InferShapedTypeOpInterface`.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

/// Helper method to get the `Value` that is the shape of the `resultIdx`-th
/// result at dimension `dimIndex` from the `ShapedTypeOpInterface`.
/// TODO(ravishankarm): This is better put as an interface utility method
/// somewhere, but that would imply the interface will depend on the `tensor`
/// dialect. Ideally maybe a utility method in the `tensor` dialect.
static Value getResultDimFromShapeInterface(OpBuilder &builder, OpResult result,
                                            int64_t dimIndex) {
  unsigned resultNumber = result.getResultNumber();
  auto shapedTypeOp = dyn_cast<InferShapedTypeOpInterface>(result.getOwner());
  Location loc = result.getOwner()->getLoc();
  if (!shapedTypeOp)
    return nullptr;

  // The interface exposes two methods, one that returns the shape of all the
  // results as `Value` and another that returns the shape as a list of
  // `SmallVector<Value>`. The former takes precedence over the latter. So first
  // check if the op implements the first interface method or the second, and
  // get the value to use appropriately.
  SmallVector<Value> reifiedResultShapes;
  if (succeeded(shapedTypeOp.reifyReturnTypeShapes(
          builder, result.getOwner()->getOperands(), reifiedResultShapes))) {
    if (reifiedResultShapes.size() <= resultNumber)
      return nullptr;
    Value resultShape = reifiedResultShapes[resultNumber];
    auto resultShapeType = resultShape.getType().dyn_cast<RankedTensorType>();
    if (!resultShapeType || !resultShapeType.getElementType().isa<IndexType>())
      return nullptr;
    return builder.create<tensor::ExtractOp>(
        loc, resultShape, builder.createOrFold<ConstantIndexOp>(loc, dimIndex));
  }

  SmallVector<SmallVector<Value>> reifiedResultShapesPerDim;
  if (failed(shapedTypeOp.reifyReturnTypeShapesPerResultDim(
          builder, reifiedResultShapesPerDim)))
    return nullptr;
  if (reifiedResultShapesPerDim.size() <= resultNumber ||
      reifiedResultShapesPerDim[resultNumber].size() !=
          static_cast<size_t>(result.getType().cast<ShapedType>().getRank()))
    return nullptr;
  OpFoldResult valueOrAttr = reifiedResultShapesPerDim[resultNumber][dimIndex];
  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
    return builder.createOrFold<ConstantIndexOp>(
        loc, attr.cast<IntegerAttr>().getInt());
  return valueOrAttr.get<Value>();
}

namespace {
/// Fold dim of an operation that implements the InferShapedTypeOpInterface
struct DimOfShapedTypeOpInterface : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    OpResult dimValue = dimOp.memrefOrTensor().dyn_cast<OpResult>();
    if (!dimValue)
      return failure();
    auto shapedTypeOp =
        dyn_cast<InferShapedTypeOpInterface>(dimValue.getOwner());
    if (!shapedTypeOp)
      return failure();

    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
    if (!dimIndex)
      return failure();
    Value replacement =
        getResultDimFromShapeInterface(rewriter, dimValue, *dimIndex);
    if (!replacement)
      return failure();
    rewriter.replaceOp(dimOp, replacement);
    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

struct ResolveShapedTypeResultDimsPass final
    : public ResolveShapedTypeResultDimsBase<ResolveShapedTypeResultDimsPass> {
  void runOnOperation() override;
};
} // namespace

void memref::populateResolveShapedTypeResultDimsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<DimOfShapedTypeOpInterface>(patterns.getContext());
}

void ResolveShapedTypeResultDimsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
  if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                          std::move(patterns))))
    return signalPassFailure();
}

std::unique_ptr<Pass> memref::createResolveShapedTypeResultDimsPass() {
  return std::make_unique<ResolveShapedTypeResultDimsPass>();
}
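The test updates below run the new pass ahead of `canonicalize` and `cse`; a hedged sketch of the same ordering in a C++ pipeline, assuming the standard `createCanonicalizerPass`/`createCSEPass` entry points (`addShapeResolutionPasses` is a placeholder name):

    #include "mlir/Dialect/MemRef/Transforms/Passes.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"

    static void addShapeResolutionPasses(mlir::PassManager &pm) {
      // Resolve memref.dim of results first so canonicalization and CSE can
      // fold the constants and affine.apply ops this introduces.
      pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
      pm.addPass(mlir::createCanonicalizerPass());
      pm.addPass(mlir::createCSEPass());
    }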
@@ -532,205 +532,6 @@ func @init_tensor_canonicalize() -> (tensor<4x5x?xf32>) {

// -----

func @init_tensor_static_dim() -> (index, index) {
  %c0 = constant 0 : index
  %c2 = constant 2 : index
  %c6 = constant 6 : index
  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
  %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
  return %1, %2 : index, index
}
// CHECK: func @init_tensor_static_dim
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C6:.+]] = constant 6 : index
// CHECK: return %[[C6]], %[[C4]]

// -----

func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
  %c2 = constant 2 : index
  %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
  return %1 : index
}
// CHECK: func @init_tensor_dynamic_dim
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]]

// -----

func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2 : index, index
}
// CHECK: func @init_tensor_dynamic_dim2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]], %[[ARG1]]

// -----

func @remove_dim_result_uses
  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   %arg2 : tensor<?x?xf32>) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                      affine_map<(d0, d1, d2) -> (d2, d1)>,
                      affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
      %1 = mulf %arg3, %arg4 : f32
      %2 = addf %1, %arg5 : f32
      linalg.yield %2 : f32
    } -> tensor<?x?xf32>
  %3 = memref.dim %0, %c0 : tensor<?x?xf32>
  %4 = memref.dim %0, %c1 : tensor<?x?xf32>
  return %3, %4 : index, index
}
// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (-s0 + s1)>
// CHECK: func @remove_dim_result_uses
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
// CHECK: return %[[T2]], %[[T5]]

// -----

func @remove_dim_result_uses_outs
  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
  %1 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%arg2: f32, %arg3: f32) :
      linalg.yield %arg2 : f32
    } -> tensor<?x?xf32>
  %2 = memref.dim %1, %c1 : tensor<?x?xf32>
  return %2 : index
}
// CHECK: func @remove_dim_result_uses_outs
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG1]]

// -----

func @remove_dim_result_uses_sequence
  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
  %3 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
                      affine_map<(d0, d1, d2) -> (d0, d2)>,
                      affine_map<(d0, d1, d2) -> (d0, d2)>],
     iterator_types = ["parallel", "reduction", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%0 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
      %4 = mulf %arg3, %arg4 : f32
      %5 = addf %4, %arg5 : f32
      linalg.yield %5 : f32
    } -> tensor<?x?xf32>
  %6 = memref.dim %3, %c0 : tensor<?x?xf32>
  %7 = memref.dim %3, %c1 : tensor<?x?xf32>
  return %1, %2, %6, %7 : index, index, index, index
}
// CHECK-LABEL: func @remove_dim_result_uses_sequence
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]

// -----

func @keep_result_dim_uses_sequence2
  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
  %1 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%arg2: f32, %arg3 : f32):
      linalg.yield %arg2 : f32
    } -> tensor<?x?xf32>
  %2 = memref.dim %1, %c0 : tensor<?x?xf32>
  %3 = memref.dim %1, %c1 : tensor<?x?xf32>
  return %2, %3 : index, index
}
// CHECK: func @keep_result_dim_uses_sequence2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: return %[[T0]], %[[ARG1]]

// -----

#map = affine_map<(d0) -> (d0)>

func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
    %arg_1: tensor<?xf32>) -> (index, index) {
  %0, %1 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel"]
    } ins(%arg_0 : tensor<?xf32>)
      outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%in: f32, %out_0: f32, %out_1: f32):
      linalg.yield %in, %in : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)

  %c0 = constant 0 : index
  %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>

  %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
  return %num_elem_0, %num_elem_1 : index, index
}
// CHECK: func @init_tensor_dim_of_linalg_result(
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
// CHECK: return %[[R0]], %[[R1]]

// -----

func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
  %0 = linalg.init_tensor [6, 5, %arg0] : tensor<6x5x?xf32>
  %1 = linalg.tensor_expand_shape %0 [[0, 1], [2], [3, 4, 5]]
@@ -740,9 +541,12 @@ func @init_tensor_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @init_tensor_reshape_expansion
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK: %[[T1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
-// CHECK: return %[[T1]]
+// CHECK: %[[C2:.+]] = constant 2
+// CHECK: %[[INIT1:.+]] = linalg.init_tensor [6, 5, %[[ARG0]]]
+// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C2]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[INIT2:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[T0]], 7]
+// CHECK: return %[[INIT2]]

// -----

@@ -755,9 +559,12 @@ func @init_tensor_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @init_tensor_reshape_collapse
// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK: %[[T1:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
-// CHECK: return %[[T1]]
+// CHECK: %[[C4:.+]] = constant 4
+// CHECK: %[[INIT1:.+]] = linalg.init_tensor [2, 3, 5, 4, %[[ARG0]], 7]
+// CHECK: %[[D0:.+]] = memref.dim %[[INIT1]], %[[C4]]
+// CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
+// CHECK: %[[INIT2:.+]] = linalg.init_tensor [6, 5, %[[T0]]]
+// CHECK: return %[[INIT2]]

// -----

@@ -906,54 +713,6 @@ func @pad_tensor_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
  } : tensor<5x6xf32> to tensor<5x6xf32>
  return %0 : tensor<5x6xf32>
}

// -----

func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
{
  %c1 = constant 1 : index
  %c3 = constant 3 : index
  %c4 = constant 4 : index
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
  %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
  %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
  %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
  return %1, %2, %3 : index, index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C3]], %[[C4]], %[[D1]]

// -----

func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
{
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
  %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
  %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
  return %1, %2 : index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @dim_reshape_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C5]], %[[D1]]

// -----

func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
    %arg3 : index) -> tensor<?x?xf32> {
  %c0 = constant 0 : index
@@ -1083,41 +842,6 @@ func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,

// -----

func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
    %arg3: f32) -> (index, index, index)
{
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c3 = constant 3 : index
  %c4 = constant 4 : index
  %c5 = constant 5 : index
  %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
    ^bb0(%arg4: index, %arg5: index, %arg6: index):
      linalg.yield %arg3 : f32
    } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
  %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
  return %1, %2, %3 : index, index, index
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 4)>
// CHECK: func @dim_of_pad_op
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C12:.+]] = constant 12 : index
// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>

func @indexed_generic(%arg0: memref<?x?xindex>, %arg1: memref<?x?xindex>) {
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s

module {
  func @three_op_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP
+// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-tiled-loop-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s --check-prefix=TLOOP

module {
  func @matmul_fusion(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
@@ -0,0 +1,278 @@
// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s

func @init_tensor_static_dim() -> (index, index) {
  %c0 = constant 0 : index
  %c2 = constant 2 : index
  %c6 = constant 6 : index
  %0 = linalg.init_tensor [4, 5, %c6] : tensor<4x5x?xf32>
  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
  %2 = memref.dim %0, %c0 : tensor<4x5x?xf32>
  return %1, %2 : index, index
}
// CHECK: func @init_tensor_static_dim
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C6:.+]] = constant 6 : index
// CHECK: return %[[C6]], %[[C4]]

// -----

func @init_tensor_dynamic_dim(%arg0 : index) -> (index) {
  %c2 = constant 2 : index
  %0 = linalg.init_tensor [4, 5, %arg0] : tensor<4x5x?xf32>
  %1 = memref.dim %0, %c2 : tensor<4x5x?xf32>
  return %1 : index
}
// CHECK: func @init_tensor_dynamic_dim
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]]

// -----

func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
  return %1, %2 : index, index
}
// CHECK: func @init_tensor_dynamic_dim2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG0]], %[[ARG1]]

// -----

func @remove_dim_result_uses
  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   %arg2 : tensor<?x?xf32>) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                      affine_map<(d0, d1, d2) -> (d2, d1)>,
                      affine_map<(d0, d1, d2) -> (d0 + d1, d1 - d0)>],
     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
      %1 = mulf %arg3, %arg4 : f32
      %2 = addf %1, %arg5 : f32
      linalg.yield %2 : f32
    } -> tensor<?x?xf32>
  %3 = memref.dim %0, %c0 : tensor<?x?xf32>
  %4 = memref.dim %0, %c1 : tensor<?x?xf32>
  return %3, %4 : index, index
}
// CHECK: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 - s0)>
// CHECK: func @remove_dim_result_uses
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T2:.+]] = affine.apply #[[MAP0]]()[%[[T0]], %[[T1]]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T4:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[T5:.+]] = affine.apply #[[MAP1]]()[%[[T3]], %[[T4]]]
// CHECK: return %[[T2]], %[[T5]]

// -----

func @remove_dim_result_uses_outs
  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
  %1 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%arg2: f32, %arg3: f32) :
      linalg.yield %arg2 : f32
    } -> tensor<?x?xf32>
  %2 = memref.dim %1, %c1 : tensor<?x?xf32>
  return %2 : index
}
// CHECK: func @remove_dim_result_uses_outs
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK: return %[[ARG1]]

// -----

func @remove_dim_result_uses_sequence
  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
  %3 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
                      affine_map<(d0, d1, d2) -> (d0, d2)>,
                      affine_map<(d0, d1, d2) -> (d0, d2)>],
     iterator_types = ["parallel", "reduction", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%0 : tensor<?x?xf32>) {
    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
      %4 = mulf %arg3, %arg4 : f32
      %5 = addf %4, %arg5 : f32
      linalg.yield %5 : f32
    } -> tensor<?x?xf32>
  %6 = memref.dim %3, %c0 : tensor<?x?xf32>
  %7 = memref.dim %3, %c1 : tensor<?x?xf32>
  return %1, %2, %6, %7 : index, index, index, index
}
// CHECK-LABEL: func @remove_dim_result_uses_sequence
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[T1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[T2:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[T3:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]

// -----

func @keep_result_dim_uses_sequence2
  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %d0 = memref.dim %arg0, %c0 : tensor<?xf32>
  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
  %1 = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
    ^bb0(%arg2: f32, %arg3 : f32):
      linalg.yield %arg2 : f32
    } -> tensor<?x?xf32>
  %2 = memref.dim %1, %c0 : tensor<?x?xf32>
  %3 = memref.dim %1, %c1 : tensor<?x?xf32>
  return %2, %3 : index, index
}
// CHECK: func @keep_result_dim_uses_sequence2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: return %[[T0]], %[[ARG1]]

// -----

#map = affine_map<(d0) -> (d0)>

func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
    %arg_1: tensor<?xf32>) -> (index, index) {
  %0, %1 = linalg.generic {
      indexing_maps = [#map, #map, #map],
      iterator_types = ["parallel"]
    } ins(%arg_0 : tensor<?xf32>)
      outs(%arg_0, %arg_1 : tensor<?xf32>, tensor<?xf32>) {
    ^bb0(%in: f32, %out_0: f32, %out_1: f32):
      linalg.yield %in, %in : f32, f32
    } -> (tensor<?xf32>, tensor<?xf32>)

  %c0 = constant 0 : index
  %num_elem_0 = memref.dim %0, %c0 : tensor<?xf32>

  %num_elem_1 = memref.dim %1, %c0 : tensor<?xf32>
  return %num_elem_0, %num_elem_1 : index, index
}
// CHECK: func @init_tensor_dim_of_linalg_result(
// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
// CHECK: %[[R0:.+]] = memref.dim %[[ARG_0]]
// CHECK: %[[R1:.+]] = memref.dim %[[ARG_0]]
// CHECK: return %[[R0]], %[[R1]]

// -----

func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
{
  %c1 = constant 1 : index
  %c3 = constant 3 : index
  %c4 = constant 4 : index
  %0 = linalg.tensor_expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
  %1 = memref.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
  %2 = memref.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
  %3 = memref.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
  return %1, %2, %3 : index, index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C3]], %[[C4]], %[[D1]]

// -----

func @dim_reshape_collapse(%arg0 : tensor<2x3x5x4x?x7xf32>) -> (index, index)
{
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0 = linalg.tensor_collapse_shape %arg0 [[0, 1], [2], [3, 4, 5]]
      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
  %1 = memref.dim %0, %c1 : tensor<6x5x?xf32>
  %2 = memref.dim %0, %c2 : tensor<6x5x?xf32>
  return %1, %2 : index, index
}
// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
// CHECK: func @dim_reshape_collapse
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3x5x4x?x7xf32>
// CHECK-DAG: %[[C4:.+]] = constant 4 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C4]]
// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK: return %[[C5]], %[[D1]]

// -----

func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,
    %arg3: f32) -> (index, index, index)
{
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c3 = constant 3 : index
  %c4 = constant 4 : index
  %c5 = constant 5 : index
  %0 = linalg.pad_tensor %arg0 low[%c3, %arg1, %c4] high[7, %c5, %arg2] {
    ^bb0(%arg4: index, %arg5: index, %arg6: index):
      linalg.yield %arg3 : f32
    } : tensor<2x?x?xf32> to tensor<?x?x?xf32>
  %1 = memref.dim %0, %c0 : tensor<?x?x?xf32>
  %2 = memref.dim %0, %c1 : tensor<?x?x?xf32>
  %3 = memref.dim %0, %c2 : tensor<?x?x?xf32>
  return %1, %2, %3 : index, index, index
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 + s0 + 4)>
// CHECK: func @dim_of_pad_op
// CHECK-SAME: %[[ARG0:[A-Za-z0-9_]+]]: tensor<2x?x?xf32>
// CHECK-SAME: %[[ARG1:[A-Za-z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[A-Za-z0-9_]+]]: index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C12:.+]] = constant 12 : index
// CHECK: %[[IN_DIM1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[OUT_DIM1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[IN_DIM1]]]
// CHECK: %[[IN_DIM2:.+]] = memref.dim %[[ARG0]], %[[C2]]
// CHECK: %[[OUT_DIM2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[IN_DIM2]]]
// CHECK: return %[[C12]], %[[OUT_DIM1]], %[[OUT_DIM2]]
@@ -205,16 +205,14 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x

// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
// CHECK: #[[BOUND8_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 8, -d0 + s1)>
-// CHECK: #[[BOUND8_MAP_3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 8)>
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
-// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
+// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 16, -d0 + s1)>
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
-// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s0, 4, -d0 + s1)>
// CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, 2, -d1 + s1)>
-// CHECK: #[[BOUND2_MAP_3:.+]] = affine_map<(d0, d1)[s0] -> (-d0 + s0, 2, -d1 + s0)>

// CHECK: func @conv_tensors_dynamic
// CHECK-SAME: (%[[INPUT]]: tensor<?x?x?x?xf32>, %[[FILTER]]: tensor<?x?x?x?xf32>, %[[ELEM]]: tensor<?x?x?x?xf32>)
@@ -240,16 +238,20 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_N:.+]] = memref.dim %[[FILL]], %[[C0]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_H:.+]] = memref.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_W:.+]] = memref.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[FILL_C:.+]] = memref.dim %[[FILL]], %[[C3]] : tensor<?x?x?x?xf32>

// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]]
-// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_3]](%[[IV0]])[%[[ELEM_N]]]
+// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[FILL_N]], %[[ELEM_N]]]
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[ELEM_OH]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
@@ -257,7 +259,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[ELEM_OW]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
@@ -266,7 +268,7 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILTER_OC]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
-// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_3]](%[[IV3]], %[[IV2]])[%[[ELEM_OC]]]
+// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND2_MAP_2]](%[[IV3]], %[[IV2]])[%[[FILL_C]], %[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
@@ -0,0 +1,88 @@
// RUN: mlir-opt %s -resolve-shaped-type-result-dims -split-input-file | FileCheck %s

func @result_shape(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
    -> (index, index, index, index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0:2 = "test.op_with_result_shape_interface"(%arg0, %arg1)
      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
  return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]

// -----

func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
    -> (index, index, index, index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
  return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape_per_dim(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]

// -----

func @result_shape_and_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
    -> (index, index, index, index, index) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0:2 = "test.op_with_result_shape_and_per_dim_interface"(%arg0, %arg1)
      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
  return %1, %2, %3, %4, %5 : index, index, index, index, index
}
// CHECK-LABEL: func @result_shape_and_per_dim(
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
// CHECK-DAG: %[[S0:.+]] = tensor.from_elements %[[D0]], %[[C5]]
// CHECK-DAG: %[[D0_OUT:.+]] = tensor.extract %[[S0]][%[[C0]]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
// CHECK-DAG: %[[S1:.+]] = tensor.from_elements %[[C2]], %[[C3]], %[[D1]]
// CHECK-DAG: %[[D1_OUT:.+]] = tensor.extract %[[S1]][%[[C2]]]
// CHECK: return %[[D0_OUT]], %[[C5]], %[[C2]], %[[C3]], %[[D1_OUT]]
@@ -82,30 +82,6 @@ func @typemismatch() -> i32 {
  return %0 : i32
}

// CHECK-LABEL: func @result_shape_per_dim
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<2x3x?xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<?x5xf32>)
func @result_shape_per_dim(%arg0 : tensor<2x3x?xf32>, %arg1 : tensor<?x5xf32>)
    -> (index, index, index, index, index) {
  // CHECK-DAG: %[[C0:.+]] = constant 0 : index
  // CHECK-DAG: %[[C2:.+]] = constant 2 : index
  // CHECK-DAG: %[[C3:.+]] = constant 3 : index
  // CHECK-DAG: %[[C5:.+]] = constant 5 : index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %0:2 = "test.op_with_result_shape_per_dim_interface"(%arg0, %arg1)
      : (tensor<2x3x?xf32>, tensor<?x5xf32>) -> (tensor<?x5xf32>, tensor<2x3x?xf32>)
  %1 = memref.dim %0#0, %c0 : tensor<?x5xf32>
  %2 = memref.dim %0#0, %c1 : tensor<?x5xf32>
  %3 = memref.dim %0#1, %c0 : tensor<2x3x?xf32>
  %4 = memref.dim %0#1, %c1 : tensor<2x3x?xf32>
  %5 = memref.dim %0#1, %c2 : tensor<2x3x?xf32>
  // CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG_1]], %[[C0]]
  // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
  // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
  return %1, %2, %3, %4, %5 : index, index, index, index, index
}

// CHECK-LABEL: test_dialect_canonicalizer
func @test_dialect_canonicalizer() -> (i32) {
  %0 = "test.dialect_canonicalizable"() : () -> (i32)
@@ -65,6 +65,7 @@ add_mlir_library(MLIRTestDialect
  MLIRReduce
  MLIRStandard
  MLIRStandardOpsTransforms
+  MLIRTensor
  MLIRTransformUtils
  MLIRTransforms
)
@@ -13,6 +13,7 @@
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h"
@@ -802,22 +803,75 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
  return success();
}

+LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  for (Value operand : llvm::reverse(operands)) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.push_back(builder.create<tensor::FromElementsOp>(
+        getLoc(), builder.getIndexType(), currShape));
+  }
+  return success();
+}
+
LogicalResult
OpWithResultShapePerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
    OpBuilder &builder,
    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
-  SmallVector<Value> operand1Shape, operand2Shape;
  Location loc = getLoc();
-  for (auto i :
-       llvm::seq<int>(0, operand1().getType().cast<ShapedType>().getRank())) {
-    operand1Shape.push_back(builder.create<memref::DimOp>(loc, operand1(), i));
+  shapes.reserve(getNumOperands());
+  for (Value operand : llvm::reverse(getOperands())) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.emplace_back(std::move(currShape));
  }
-  for (auto i :
-       llvm::seq<int>(0, operand2().getType().cast<ShapedType>().getRank())) {
-    operand2Shape.push_back(builder.create<memref::DimOp>(loc, operand2(), i));
+  return success();
}

+LogicalResult OpWithResultShapeAndPerDimInterfaceOp::reifyReturnTypeShapes(
+    OpBuilder &builder, ValueRange operands,
+    llvm::SmallVectorImpl<Value> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(operands.size());
+  for (Value operand : llvm::reverse(operands)) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.push_back(builder.create<tensor::FromElementsOp>(
+        getLoc(), builder.getIndexType(), currShape));
+  }
+  return success();
+}
+
+LogicalResult
+OpWithResultShapeAndPerDimInterfaceOp ::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &builder,
+    llvm::SmallVectorImpl<llvm::SmallVector<Value>> &shapes) {
+  Location loc = getLoc();
+  shapes.reserve(getNumOperands());
+  for (Value operand : llvm::reverse(getOperands())) {
+    auto currShape = llvm::to_vector<4>(llvm::map_range(
+        llvm::seq<int64_t>(
+            0, operand.getType().cast<RankedTensorType>().getRank()),
+        [&](int64_t dim) -> Value {
+          return builder.createOrFold<memref::DimOp>(loc, operand, dim);
+        }));
+    shapes.emplace_back(std::move(currShape));
+  }
-  shapes.emplace_back(std::move(operand2Shape));
-  shapes.emplace_back(std::move(operand1Shape));
  return success();
}
|
@ -571,9 +571,25 @@ def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_ty
|
|||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def OpWithResultShapePerDimInterfaceOp : TEST_Op<"op_with_result_shape_per_dim_interface",
|
||||
def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
|
||||
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapesPerResultDim"]>]> {
|
||||
["reifyReturnTypeShapes"]>]> {
|
||||
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
|
||||
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
|
||||
}
|
||||
|
||||
def OpWithResultShapePerDimInterfaceOp :
|
||||
TEST_Op<"op_with_result_shape_per_dim_interface",
|
||||
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapesPerResultDim"]>]> {
|
||||
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
|
||||
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
|
||||
}
|
||||
|
||||
def OpWithResultShapeAndPerDimInterfaceOp :
|
||||
TEST_Op<"op_with_result_shape_and_per_dim_interface",
|
||||
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapes", "reifyReturnTypeShapesPerResultDim"]>]> {
|
||||
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
|
||||
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
|
||||
}
|
||||
|
|