From e5a315f57acf5580aa8819123300d90b4f7a160a Mon Sep 17 00:00:00 2001
From: MaheshRavishankar
Date: Tue, 25 Jan 2022 10:36:34 -0800
Subject: [PATCH] [mlir][Linalg] Disallow ops with index semantics in
 `PushExpandingReshape`.

This pattern is not written to handle operations that have `linalg.index`
operations in their body, i.e. operations with index semantics.

Differential Revision: https://reviews.llvm.org/D117856
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp |  2 +-
 .../Dialect/Linalg/fusion-push-reshape.mlir   | 27 +++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index be34ef8bbd62..aaa5d4c38620 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -994,7 +994,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Only apply to elementwise linalg on tensor.
-    if (!genericOp.hasTensorSemantics() ||
+    if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
         genericOp.getNumParallelLoops() != genericOp.getNumLoops())
       return failure();
     // Only support identity output maps. It could be extended to permuations if
diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
index 9e96c98e7850..0c02ff8c54d1 100644
--- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir
@@ -124,3 +124,30 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
 // CHECK-SAME:   outs(%{{.+}} : tensor<6x5xf32>)
 // CHECK:      tensor.expand_shape %[[OP]]
 // CHECK-SAME:   tensor<6x5xf32> into tensor<2x3x5xf32>
+
+// -----
+
+func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
+  %0 = tensor.expand_shape %A [[0, 1], [2]]
+      : tensor<?x16xi64> into tensor<?x112x16xi64>
+  %2 = linalg.generic {indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
+    outs(%init : tensor<?x112x16xi64>) {
+  ^bb0(%arg1: i64, %arg2: i64, %arg3: i64):  // no predecessors
+    %index = linalg.index 0 : index
+    %1 = arith.index_cast %index : index to i64
+    %add = arith.addi %arg1, %1 : i64
+    %s = arith.subi %add, %arg2 : i64
+    linalg.yield %s : i64
+  } -> tensor<?x112x16xi64>
+  return %2 : tensor<?x112x16xi64>
+}
+// CHECK:      func @generic_op_index_semantics
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x16xi64>
+// CHECK:        %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK:        %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     ins(%[[RESHAPE]]
+// CHECK:        return %[[RESULT]]
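
Editor's note (not part of the patch): `PushExpandingReshape` sinks a `tensor.expand_shape` below an elementwise `linalg.generic`, i.e. it rewrites the generic to run on the collapsed operands and re-expands the result afterwards. If the body uses `linalg.index`, that rewrite changes the iteration space the indices refer to, so the computed values would change unless the collapsed index were delinearized back into the original dimensions, which this pattern does not do. The sketch below is a hedged illustration of roughly what such a rewrite would look like for the new test; the function name, the explicit `tensor.collapse_shape` of the init, and the exact rewritten form are assumptions for illustration, not output of the actual pattern.

// Hypothetical rewritten form (sketch only, assuming the pattern fired):
func @hypothetical_pushed_reshape(%A: tensor<?x16xi64>, %B: tensor<16xi64>,
    %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
  // Collapse the init so the generic can run on the un-expanded shapes.
  %cinit = tensor.collapse_shape %init [[0, 1], [2]]
      : tensor<?x112x16xi64> into tensor<?x16xi64>
  %res = linalg.generic {indexing_maps = [
      affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>,
      affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%A, %B : tensor<?x16xi64>, tensor<16xi64>)
      outs(%cinit : tensor<?x16xi64>) {
    ^bb0(%a: i64, %b: i64, %out: i64):
      // `linalg.index 0` now ranges over the collapsed dimension; recovering
      // the original d0/d1 of the expanded space would require
      // delinearization, which the pattern does not implement, so the
      // computed values would silently change.
      %idx = linalg.index 0 : index
      %cast = arith.index_cast %idx : index to i64
      %add = arith.addi %a, %cast : i64
      %sub = arith.subi %add, %b : i64
      linalg.yield %sub : i64
    } -> tensor<?x16xi64>
  // Re-expand the result to the original type.
  %expanded = tensor.expand_shape %res [[0, 1], [2]]
      : tensor<?x16xi64> into tensor<?x112x16xi64>
  return %expanded : tensor<?x112x16xi64>
}

Guarding on `hasIndexSemantics()` therefore simply rejects such ops up front, which is what the one-line change in ElementwiseOpFusion.cpp and the negative test above verify.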