[mlir][Linalg] Add folders and canonicalizers for linalg.reshape/linalg.tensor_reshape operations.

Differential Revision: https://reviews.llvm.org/D79765
MaheshRavishankar 2020-05-12 22:50:44 -07:00
parent d2a9569850
commit 5440d0a12d
4 changed files with 349 additions and 65 deletions
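
In brief: both ops gain a folder that removes a reshape/reshape pair whose
composition is the identity (checkable only for static shapes), plus a
canonicalization pattern that composes two reshapes that both collapse or both
expand dimensions. A minimal sketch of the pattern, mirroring the tests added
below:

    %0 = linalg.tensor_reshape %arg0
           [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
            affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
            affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
         tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
    %1 = linalg.tensor_reshape %0
           [affine_map<(d0, d1, d2) -> (d0, d1)>,
            affine_map<(d0, d1, d2) -> (d2)>] :
         tensor<?x?x?xf32> into tensor<?x?xf32>

canonicalizes to the single op

    %1 = linalg.tensor_reshape %arg0
           [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
            affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
         tensor<?x?x?x?x?xf32> into tensor<?x?xf32>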


@@ -77,6 +77,10 @@ class Linalg_ReshapeLikeOp<string mnemonic> :
  code commonExtraClassDeclaration = [{
    static StringRef getReassociationAttrName() { return "reassociation"; }
    SmallVector<AffineMap, 4> getReassociationMaps() {
      return llvm::to_vector<4>(llvm::map_range(reassociation(),
          [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
    }
  }];
  let assemblyFormat = [{
    $src $reassociation attr-dict `:` type($src) `into` type(results)
@@ -137,6 +141,7 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
    MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
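
// Note: with hasFolder/hasCanonicalizer set, ODS declares the standard hooks
// on each op (their definitions are added in LinalgOps.cpp below):
//
//   OpFoldResult fold(ArrayRef<Attribute> operands);
//   static void getCanonicalizationPatterns(OwningRewritePatternList &results,
//                                           MLIRContext *context);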
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
@@ -187,11 +192,9 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
    RankedTensorType getResultType() {
      return result().getType().cast<RankedTensorType>();
    }
    SmallVector<AffineMap, 4> getReassociationMaps() {
      return llvm::to_vector<4>(llvm::map_range(reassociation(),
          [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}
def Linalg_SliceOp : Linalg_Op<"slice", [


@@ -246,6 +246,108 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
/// Collapse reassociation maps that are used in a pair of reshape ops, where
/// one op is the producer and the other is the consumer. This is only valid
/// when both the producer and the consumer collapse dimensions, or both
/// expand dimensions.
///
/// For example,
/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
/// affine_map<(d0, d1, d2) -> (d2)>]
///
/// is folded into
///
/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
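///
/// At the IR level this corresponds to rewriting (an illustrative sketch; the
/// canonicalize.mlir tests added below exercise exactly this case):
///
///   %0 = linalg.tensor_reshape %arg0 [mapsProducer] :
///          tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
///   %1 = linalg.tensor_reshape %0 [mapsConsumer] :
///          tensor<?x?x?xf32> into tensor<?x?xf32>
///
/// into a single linalg.tensor_reshape using the folded maps.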
static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
                                           ArrayRef<AffineMap> mapsConsumer,
                                           MLIRContext *context) {
  if (mapsProducer.size() == 0 || mapsConsumer.size() == 0 ||
      mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
      mapsProducer.size() != mapsConsumer[0].getNumDims())
    return nullptr;
  unsigned numLhsDims = mapsProducer[0].getNumDims();
  unsigned currDim = 0;
  SmallVector<AffineExpr, 4> reassociations;
  SmallVector<Attribute, 4> reassociationMaps;
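  // For each result of a consumer map (i.e. each dimension of the intermediate
  // type), append the contiguous run of producer dimensions that compose it,
  // numbered by `currDim`. One collapsed map is emitted per consumer map, that
  // is, per dimension of the final result type.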
  for (AffineMap rhs : mapsConsumer) {
    for (AffineExpr rhsExpr : rhs.getResults()) {
      AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
      for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
           i != e; ++i) {
        reassociations.push_back(getAffineDimExpr(currDim++, context));
      }
    }
    reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
        numLhsDims, /*numSymbols=*/0, reassociations, context)));
    reassociations.clear();
  }
  return ArrayAttr::get(reassociationMaps, context);
}
namespace {
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
template <typename ReshapeOpTy>
struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto srcReshapeOp =
        dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
    if (!srcReshapeOp)
      return failure();

    auto areReshapeOpsFoldable = [](ShapedType largerType,
                                    ShapedType intermediateType,
                                    ShapedType smallerType) -> bool {
      return largerType.getRank() > intermediateType.getRank() &&
             intermediateType.getRank() > smallerType.getRank() &&
             smallerType.getRank() > 0;
    };
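    // The rank must strictly decrease from `largerType` through `smallerType`:
    // both reshapes then move in the same direction, so their reassociation
    // maps can be composed into a single reshape.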
    // Check if producer and consumer are both expanding dims.
    if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
                              srcReshapeOp.getSrcType())) {
      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
          reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
          collapseReassociationMaps(reshapeOp.getReassociationMaps(),
                                    srcReshapeOp.getReassociationMaps(),
                                    rewriter.getContext()));
      return success();
    }
    // Check if producer and consumer are both collapsing dims.
    else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(),
                                   reshapeOp.getSrcType(),
                                   reshapeOp.getResultType())) {
      rewriter.replaceOpWithNewOp<ReshapeOpTy>(
          reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
          collapseReassociationMaps(srcReshapeOp.getReassociationMaps(),
                                    reshapeOp.getReassociationMaps(),
                                    rewriter.getContext()));
      return success();
    }
    return failure();
  }
};
} // namespace
template <typename ReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) {
  // Fold producer-consumer reshape ops where the operand type of the producer
  // is the same as the return type of the consumer. This can only be verified
  // if the shapes in question are static.
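  // For example (a sketch; the @fold_tensor_reshape test added below exercises
  // this):
  //   %0 = linalg.tensor_reshape %arg0 [...]
  //          : tensor<12x4xf32> into tensor<3x4x4xf32>
  //   %1 = linalg.tensor_reshape %0 [...]
  //          : tensor<3x4x4xf32> into tensor<12x4xf32>
  // folds %1 away and replaces all its uses with %arg0.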
  ReshapeOpTy reshapeSrcOp =
      dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
  if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
      reshapeOp.getResultType().hasStaticShape() &&
      reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
    return reshapeSrcOp.src();
  return nullptr;
}
/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map.
@@ -482,6 +584,11 @@ static LogicalResult verify(ReshapeOp op) {
  return success();
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<CollapseReshapeOps<ReshapeOp>>(context);
}
//===----------------------------------------------------------------------===//
// TensorReshapeOp
//===----------------------------------------------------------------------===//
@@ -551,6 +658,11 @@ static LogicalResult verify(TensorReshapeOp op) {
  return success();
}

void TensorReshapeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
}
//===----------------------------------------------------------------------===//
// SliceOp
//===----------------------------------------------------------------------===//
@@ -1010,13 +1122,18 @@ LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
  return foldReshapeOp(*this);
}

OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}

OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return foldReshapeOp(*this);
}

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
// CHECK-LABEL: func @memref_cast(
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
@@ -18,3 +18,157 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
  linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
  return %4: memref<?x?xf32>
}
// -----
func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_reshape %arg0
         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
       tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
  %1 = linalg.tensor_reshape %0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
// CHECK-LABEL: collapsing_tensor_reshapes
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// CHECK-NOT: linalg.tensor_reshape
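// Note: #MAP0/#MAP1 above are exactly the collapsed maps computed by
// collapseReassociationMaps in LinalgOps.cpp for this producer/consumer pair
// (the same worked example appears in that function's doc comment).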
// -----
func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
{
  %0 = linalg.tensor_reshape %arg0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<?x?xf32> into tensor<?x?x?xf32>
  %1 = linalg.tensor_reshape %0
         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
       tensor<?x?x?xf32> into tensor<?x?x?x?x?xf32>
  return %1 : tensor<?x?x?x?x?xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
// CHECK-LABEL: expanding_tensor_reshapes
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// CHECK-NOT: linalg.tensor_reshape
// -----
func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
{
  %0 = linalg.reshape %arg0
         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
       memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
  %1 = linalg.reshape %0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       memref<?x?x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
// CHECK-LABEL: collapsing_memref_reshapes
// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// CHECK-NOT: linalg.reshape
// -----
func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32>
{
  %0 = linalg.reshape %arg0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       memref<?x?xf32> into memref<?x?x?xf32>
  %1 = linalg.reshape %0
         [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
          affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
       memref<?x?x?xf32> into memref<?x?x?x?x?xf32>
  return %1 : memref<?x?x?x?x?xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
// CHECK-LABEL: expanding_memref_reshapes
// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// CHECK-NOT: linalg.reshape
// -----
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
{
  %0 = linalg.tensor_reshape %arg0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<12x4xf32> into tensor<3x4x4xf32>
  %1 = linalg.tensor_reshape %0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<3x4x4xf32> into tensor<12x4xf32>
  return %1 : tensor<12x4xf32>
}
// CHECK-LABEL: @fold_tensor_reshape
// CHECK-NOT: linalg.tensor_reshape
// -----
func @no_fold_tensor_reshape(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
{
  %0 = linalg.tensor_reshape %arg0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<?x?xf32> into tensor<?x?x?xf32>
  %1 = linalg.tensor_reshape %0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: @no_fold_tensor_reshape
// CHECK: linalg.tensor_reshape
// CHECK: linalg.tensor_reshape
// -----
func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
{
  %0 = linalg.reshape %arg0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       memref<12x4xf32> into memref<3x4x4xf32>
  %1 = linalg.reshape %0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       memref<3x4x4xf32> into memref<12x4xf32>
  return %1 : memref<12x4xf32>
}
// CHECK-LABEL: @fold_memref_reshape
// CHECK-NOT: linalg.reshape
// -----
func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
{
  %0 = linalg.reshape %arg0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       memref<?x?xf32> into memref<?x?x?xf32>
  %1 = linalg.reshape %0
         [affine_map<(d0, d1, d2) -> (d0, d1)>,
          affine_map<(d0, d1, d2) -> (d2)>] :
       memref<?x?x?xf32> into memref<?x?xf32>
  return %1 : memref<?x?xf32>
}
// CHECK-LABEL: @no_fold_memref_reshape
// CHECK: linalg.reshape
// CHECK: linalg.reshape


@@ -214,72 +214,76 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
func @reshape_static(%arg0: memref<3x4x5xf32>) {
  // Reshapes that expand and collapse back a contiguous tensor with some 1's.
func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
  // Reshapes that expand a contiguous tensor with some 1's.
  %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                             affine_map<(i, j, k, l, m) -> (k)>,
                             affine_map<(i, j, k, l, m) -> (l, m)>] :
    memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
  %r0 = linalg.reshape %0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                           affine_map<(i, j, k, l, m) -> (k)>,
                           affine_map<(i, j, k, l, m) -> (l, m)>] :
    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
  return
  return %0 : memref<1x3x4x1x5xf32>
}
// CHECK-LABEL: func @reshape_static(
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK-LABEL: func @reshape_static_expand
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
  %1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
  return
func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
  %0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
                             affine_map<(i, j, k, l, m) -> (k)>,
                             affine_map<(i, j, k, l, m) -> (l, m)>] :
    memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
  return %0 : memref<3x4x5xf32>
}
// CHECK-LABEL: func @reshape_zero_dim
// CHECK-LABEL: func @reshape_static_collapse
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
  %0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
  return %0 : memref<f32>
}
// CHECK-LABEL: func @reshape_fold_zero_dim
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
@@ -287,6 +291,12 @@ func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
func @reshape_expand_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
  %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1x1xf32>
  return %0 : memref<1x1xf32>
}
// CHECK-LABEL: func @reshape_expand_zero_dim
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">