forked from OSchip/llvm-project
[mlir][Linalg] Add folders and canonicalizers for
linalg.reshape/linalg.tensor_reshape operations. Differential Revision: https://reviews.llvm.org/D79765
This commit is contained in:
parent
d2a9569850
commit
5440d0a12d
|
@ -77,6 +77,10 @@ class Linalg_ReshapeLikeOp<string mnemonic> :
|
|||
|
||||
code commonExtraClassDeclaration = [{
|
||||
static StringRef getReassociationAttrName() { return "reassociation"; }
|
||||
SmallVector<AffineMap, 4> getReassociationMaps() {
|
||||
return llvm::to_vector<4>(llvm::map_range(reassociation(), [
|
||||
](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
|
||||
}
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$src $reassociation attr-dict `:` type($src) `into` type(results)
|
||||
|
@ -137,6 +141,7 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape">,
|
|||
MemRefType getResultType() { return result().getType().cast<MemRefType>(); }
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
|
||||
|
@ -187,11 +192,9 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">,
|
|||
RankedTensorType getResultType() {
|
||||
return result().getType().cast<RankedTensorType>();
|
||||
}
|
||||
SmallVector<AffineMap, 4> getReassociationMaps() {
|
||||
return llvm::to_vector<4>(llvm::map_range(reassociation(),
|
||||
[](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Linalg_SliceOp : Linalg_Op<"slice", [
|
||||
|
|
|
@ -246,6 +246,108 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
|
|||
// ReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Collapse reassociation maps that are used in pair of reshape ops where one
|
||||
/// is a producer and other is the consumer. Only valid to use this method when
|
||||
/// both the producer and consumer are collapsing dimensions or both are
|
||||
/// expanding dimensions.
|
||||
///
|
||||
/// For example,
|
||||
/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
||||
/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d2)>]
|
||||
///
|
||||
/// is folded into
|
||||
///
|
||||
/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
||||
static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
|
||||
ArrayRef<AffineMap> mapsConsumer,
|
||||
MLIRContext *context) {
|
||||
if (mapsProducer.size() == 0 || mapsConsumer.size() == 0 ||
|
||||
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
|
||||
mapsProducer.size() != mapsConsumer[0].getNumDims())
|
||||
return nullptr;
|
||||
unsigned numLhsDims = mapsProducer[0].getNumDims();
|
||||
unsigned currDim = 0;
|
||||
SmallVector<AffineExpr, 4> reassociations;
|
||||
SmallVector<Attribute, 4> reassociationMaps;
|
||||
for (AffineMap rhs : mapsConsumer) {
|
||||
for (AffineExpr rhsExpr : rhs.getResults()) {
|
||||
AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
|
||||
for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
|
||||
i != e; ++i) {
|
||||
reassociations.push_back(getAffineDimExpr(currDim++, context));
|
||||
}
|
||||
}
|
||||
reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get(
|
||||
numLhsDims, /*numSymbols =*/0, reassociations, context)));
|
||||
reassociations.clear();
|
||||
}
|
||||
return ArrayAttr::get(reassociationMaps, context);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
|
||||
/// dimensions or are both expanding dimensions.
|
||||
template <typename ReshapeOpTy>
|
||||
struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
|
||||
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcReshapeOp =
|
||||
dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
|
||||
if (!srcReshapeOp)
|
||||
return failure();
|
||||
|
||||
auto areReshapeOpsFoldable = [](ShapedType largerType,
|
||||
ShapedType intermediateType,
|
||||
ShapedType smallerType) -> bool {
|
||||
return largerType.getRank() > intermediateType.getRank() &&
|
||||
intermediateType.getRank() > smallerType.getRank() &&
|
||||
smallerType.getRank() > 0;
|
||||
};
|
||||
// Check if producer and consumer are both expanding dims.
|
||||
if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
|
||||
srcReshapeOp.getSrcType())) {
|
||||
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
|
||||
reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
|
||||
collapseReassociationMaps(reshapeOp.getReassociationMaps(),
|
||||
srcReshapeOp.getReassociationMaps(),
|
||||
rewriter.getContext()));
|
||||
return success();
|
||||
}
|
||||
// Check if producer and consumer are both collapsing dims.
|
||||
else if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(),
|
||||
reshapeOp.getSrcType(),
|
||||
reshapeOp.getResultType())) {
|
||||
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
|
||||
reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(),
|
||||
collapseReassociationMaps(srcReshapeOp.getReassociationMaps(),
|
||||
reshapeOp.getReassociationMaps(),
|
||||
rewriter.getContext()));
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
template <typename ReshapeOpTy>
|
||||
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp) {
|
||||
// Fold producer-consumer reshape ops that where the operand type of the
|
||||
// producer is same as the return type of the consumer. This can only be
|
||||
// verified if the shapes in question are static.
|
||||
ReshapeOpTy reshapeSrcOp =
|
||||
dyn_cast_or_null<ReshapeOpTy>(reshapeOp.src().getDefiningOp());
|
||||
if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() &&
|
||||
reshapeOp.getResultType().hasStaticShape() &&
|
||||
reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
|
||||
return reshapeSrcOp.src();
|
||||
return nullptr;
|
||||
};
|
||||
|
||||
/// Return true if the reassociation specification is valid, false otherwise.
|
||||
/// When false, the `invalidIndex` integer pointer is optionally filled with the
|
||||
/// index of the offending reassociation map.
|
||||
|
@ -482,6 +584,11 @@ static LogicalResult verify(ReshapeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<CollapseReshapeOps<ReshapeOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorReshapeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -551,6 +658,11 @@ static LogicalResult verify(TensorReshapeOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
void TensorReshapeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<CollapseReshapeOps<TensorReshapeOp>>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SliceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1010,13 +1122,18 @@ LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
|
|||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
|
||||
if (succeeded(foldMemRefCast(*this)))
|
||||
return getResult();
|
||||
return {};
|
||||
return foldReshapeOp(*this);
|
||||
}
|
||||
OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
|
||||
if (succeeded(foldMemRefCast(*this)))
|
||||
return getResult();
|
||||
return {};
|
||||
}
|
||||
OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute>) {
|
||||
if (succeeded(foldMemRefCast(*this)))
|
||||
return getResult();
|
||||
return foldReshapeOp(*this);
|
||||
}
|
||||
OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
|
||||
if (succeeded(foldMemRefCast(*this)))
|
||||
return getResult();
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -canonicalize | FileCheck %s
|
||||
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @memref_cast(
|
||||
func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
|
||||
|
@ -18,3 +18,157 @@ func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
|
|||
linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
|
||||
return %4: memref<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32>
|
||||
{
|
||||
%0 = linalg.tensor_reshape %arg0
|
||||
[affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
|
||||
tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
|
||||
%1 = linalg.tensor_reshape %0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
|
||||
// CHECK-LABEL: collapsing_tensor_reshapes
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
|
||||
// CHECK-NOT: linalg.tensor_reshape
|
||||
|
||||
// -----
|
||||
|
||||
func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
{
|
||||
%0 = linalg.tensor_reshape %arg0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
tensor<?x?xf32> into tensor<?x?x?xf32>
|
||||
%1 = linalg.tensor_reshape %0
|
||||
[affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
|
||||
tensor<?x?x?xf32> into tensor<?x?x?x?x?xf32>
|
||||
return %1 : tensor<?x?x?x?x?xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
|
||||
// CHECK-LABEL: expanding_tensor_reshapes
|
||||
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
|
||||
// CHECK-NOT: linalg.tensor_reshape
|
||||
|
||||
// -----
|
||||
|
||||
func @collapsing_memref_reshapes(%arg0 : memref<?x?x?x?x?xf32>) -> memref<?x?xf32>
|
||||
{
|
||||
%0 = linalg.reshape %arg0
|
||||
[affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
|
||||
memref<?x?x?x?x?xf32> into memref<?x?x?xf32>
|
||||
%1 = linalg.reshape %0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
memref<?x?x?xf32> into memref<?x?xf32>
|
||||
return %1 : memref<?x?xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
|
||||
// CHECK-LABEL: collapsing_memref_reshapes
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
|
||||
// CHECK-NOT: linalg.reshape
|
||||
|
||||
// -----
|
||||
|
||||
func @expanding_memref_reshapes(%arg0 : memref<?x?xf32>) -> memref<?x?x?x?x?xf32>
|
||||
{
|
||||
%0 = linalg.reshape %arg0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
memref<?x?xf32> into memref<?x?x?xf32>
|
||||
%1 = linalg.reshape %0
|
||||
[affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
||||
affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] :
|
||||
memref<?x?x?xf32> into memref<?x?x?x?x?xf32>
|
||||
return %1 : memref<?x?x?x?x?xf32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
|
||||
// CHECK-LABEL: expanding_memref_reshapes
|
||||
// CHECK: linalg.reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
|
||||
// CHECK-NOT: linalg.reshape
|
||||
|
||||
// -----
|
||||
|
||||
func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
|
||||
{
|
||||
%0 = linalg.tensor_reshape %arg0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
tensor<12x4xf32> into tensor<3x4x4xf32>
|
||||
%1 = linalg.tensor_reshape %0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
tensor<3x4x4xf32> into tensor<12x4xf32>
|
||||
return %1 : tensor<12x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: @fold_tensor_reshape
|
||||
// CHECK-NOT: linalg.tensor_reshape
|
||||
|
||||
// -----
|
||||
|
||||
func @no_fold_tensor_reshape(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
{
|
||||
%0 = linalg.tensor_reshape %arg0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
tensor<?x?xf32> into tensor<?x?x?xf32>
|
||||
%1 = linalg.tensor_reshape %0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: @no_fold_tensor_reshape
|
||||
// CHECK: linalg.tensor_reshape
|
||||
// CHECK: linalg.tensor_reshape
|
||||
|
||||
// -----
|
||||
|
||||
func @fold_memref_reshape(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>
|
||||
{
|
||||
%0 = linalg.reshape %arg0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
memref<12x4xf32> into memref<3x4x4xf32>
|
||||
%1 = linalg.reshape %0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
memref<3x4x4xf32> into memref<12x4xf32>
|
||||
return %1 : memref<12x4xf32>
|
||||
}
|
||||
// CHECK-LABEL: @fold_memref_reshape
|
||||
// CHECK-NOT: linalg.reshape
|
||||
|
||||
// -----
|
||||
|
||||
func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
|
||||
{
|
||||
%0 = linalg.reshape %arg0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
memref<?x?xf32> into memref<?x?x?xf32>
|
||||
%1 = linalg.reshape %0
|
||||
[affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
affine_map<(d0, d1, d2) -> (d2)>] :
|
||||
memref<?x?x?xf32> into memref<?x?xf32>
|
||||
return %1 : memref<?x?xf32>
|
||||
}
|
||||
// CHECK-LABEL: @no_fold_memref_reshape
|
||||
// CHECK: linalg.reshape
|
||||
// CHECK: linalg.reshape
|
||||
|
|
|
@ -214,72 +214,76 @@ func @matmul_vec_indexed(%A: !matrix_type_A,
|
|||
// CHECK-SAME: !llvm<"<4 x float>*">, !llvm<"<4 x float>*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
// CHECK-SAME: !llvm<"[4 x <4 x float>]*">, !llvm<"[4 x <4 x float>]*">, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64
|
||||
|
||||
func @reshape_static(%arg0: memref<3x4x5xf32>) {
|
||||
// Reshapes that expand and collapse back a contiguous tensor with some 1's.
|
||||
func @reshape_static_expand(%arg0: memref<3x4x5xf32>) -> memref<1x3x4x1x5xf32> {
|
||||
// Reshapes that expand a contiguous tensor with some 1's.
|
||||
%0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
|
||||
affine_map<(i, j, k, l, m) -> (k)>,
|
||||
affine_map<(i, j, k, l, m) -> (l, m)>] :
|
||||
memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
|
||||
%r0 = linalg.reshape %0 [affine_map<(i, j, k, l, m) -> (i, j)>,
|
||||
affine_map<(i, j, k, l, m) -> (k)>,
|
||||
affine_map<(i, j, k, l, m) -> (l, m)>] :
|
||||
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
|
||||
return
|
||||
return %0 : memref<1x3x4x1x5xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @reshape_static(
|
||||
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue {{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK-LABEL: func @reshape_static_expand
|
||||
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(60 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
|
||||
func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
|
||||
%0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
|
||||
%1 = linalg.reshape %0 [] : memref<f32> into memref<1x1xf32>
|
||||
return
|
||||
func @reshape_static_collapse(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf32> {
|
||||
%0 = linalg.reshape %arg0 [affine_map<(i, j, k, l, m) -> (i, j)>,
|
||||
affine_map<(i, j, k, l, m) -> (k)>,
|
||||
affine_map<(i, j, k, l, m) -> (l, m)>] :
|
||||
memref<1x3x4x1x5xf32> into memref<3x4x5xf32>
|
||||
return %0 : memref<3x4x5xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @reshape_zero_dim
|
||||
// CHECK-LABEL: func @reshape_static_collapse
|
||||
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(3 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(20 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(5 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[4, 2] : !llvm<"{ float*, float*, i64, [3 x i64], [3 x i64] }">
|
||||
|
||||
func @reshape_fold_zero_dim(%arg0 : memref<1x1xf32>) -> memref<f32> {
|
||||
%0 = linalg.reshape %arg0 [] : memref<1x1xf32> into memref<f32>
|
||||
return %0 : memref<f32>
|
||||
}
|
||||
// CHECK-LABEL: func @reshape_fold_zero_dim
|
||||
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64 }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
|
||||
|
@ -287,6 +291,12 @@ func @reshape_zero_dim(%arg0 : memref<1x1xf32>) {
|
|||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
|
||||
|
||||
func @reshape_expand_zero_dim(%arg0 : memref<f32>) -> memref<1x1xf32> {
|
||||
%0 = linalg.reshape %arg0 [] : memref<f32> into memref<1x1xf32>
|
||||
return %0 : memref<1x1xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @reshape_expand_zero_dim
|
||||
// CHECK: llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, float*, i64 }">
|
||||
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||
|
|
Loading…
Reference in New Issue