From 08f0cb77197dc2842baa00f22f0264fa49d1475a Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Fri, 17 Sep 2021 10:56:21 -0700 Subject: [PATCH] [mlir] Prevent crash in DropUnitDim pattern due to tensor with encoding Differential Revision: https://reviews.llvm.org/D109984 --- .../Linalg/Transforms/DropUnitDims.cpp | 6 ++++ .../Dialect/Linalg/drop-unit-extent-dims.mlir | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 8315de4c72e7..98a06bfda97a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -361,6 +361,12 @@ struct ReplaceUnitExtents : public OpRewritePattern { LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { + // Skip the pattern if the op has any tensor with special encoding. + if (llvm::any_of(genericOp->getOperandTypes(), [](Type type) { + auto tensorType = type.dyn_cast(); + return tensorType && tensorType.getEncoding() != nullptr; + })) + return failure(); MLIRContext *context = rewriter.getContext(); Location loc = genericOp.getLoc(); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index 60ad72300a18..53bdf0aa712a 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -796,3 +796,34 @@ func @input_stays_same(%arg0 : memref, %arg1 : f32, %shape: me // CHECK: linalg.yield %[[ARG]] : f32 // CHECK: } // CHECK: return %[[ARG2]] : memref + +// ----- + +// Negative test for case with tensor encoding. +#matvec = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)>, // A + affine_map<(i,j) -> (j)>, // b + affine_map<(i,j) -> (i)> // x (out) + ], + iterator_types = ["parallel", "reduction"] +} + +#CSR = #sparse_tensor.encoding<{ dimLevelType = ["dense", "compressed"] }> + +func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> tensor<8xf32> { + %0 = linalg.init_tensor [8] : tensor<8xf32> + %1 = linalg.generic #matvec + ins(%arg0, %arg1: tensor<8x8xf32, #CSR>, tensor<8xf32>) + outs(%0: tensor<8xf32>) { + ^bb(%a: f32, %b: f32, %x: f32): + %m = mulf %a, %b : f32 + %add = addf %x, %m : f32 + linalg.yield %add : f32 + } -> tensor<8xf32> + return %1: tensor<8xf32> +} + +// CHECK-LABEL: func @sparse_case +// CHECK-NEXT: linalg.init_tensor +// CHECK-NEXT: linalg.generic