[mlir][Linalg] Disallow ops with index semantics in `PushExpandingReshape`.

This pattern is not written to handle operations with `linalg.index`
operations in its body, i.e. operations that have index semantics.

Differential Revision: https://reviews.llvm.org/D117856
This commit is contained in:
MaheshRavishankar 2022-01-25 10:36:34 -08:00
parent 93230ac1d2
commit e5a315f57a
2 changed files with 28 additions and 1 deletions

View File

@ -994,7 +994,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Only apply to elementwise linalg on tensor.
if (!genericOp.hasTensorSemantics() ||
if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
genericOp.getNumParallelLoops() != genericOp.getNumLoops())
return failure();
// Only support identity output maps. It could be extended to permuations if

View File

@ -124,3 +124,30 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
// CHECK: tensor.expand_shape %[[OP]]
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>
// -----
func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
%0 = tensor.expand_shape %A [[0, 1], [2]]
: tensor<?x16xi64> into tensor<?x112x16xi64>
%2 = linalg.generic {indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
outs(%init : tensor<?x112x16xi64>) {
^bb0(%arg1: i64, %arg2: i64, %arg3: i64): // no predecessors
%index = linalg.index 0 : index
%1 = arith.index_cast %index : index to i64
%add = arith.addi %arg1, %1 : i64
%s = arith.subi %add, %arg2 : i64
linalg.yield %s : i64
} -> tensor<?x112x16xi64>
return %2 : tensor<?x112x16xi64>
}
// CHECK: func @generic_op_index_semantics
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x16xi64>
// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK-SAME: ins(%[[RESHAPE]]
// CHECK: return %[[RESULT]]