forked from OSchip/llvm-project
[mlir][Linalg] Disallow ops with index semantics in `PushExpandingReshape`.
This pattern is not written to handle operations with `linalg.index` operations in its body, i.e. operations that have index semantics. Differential Revision: https://reviews.llvm.org/D117856
This commit is contained in:
parent
93230ac1d2
commit
e5a315f57a
|
@ -994,7 +994,7 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
|
|||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only apply to elementwise linalg on tensor.
|
||||
if (!genericOp.hasTensorSemantics() ||
|
||||
if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() ||
|
||||
genericOp.getNumParallelLoops() != genericOp.getNumLoops())
|
||||
return failure();
|
||||
// Only support identity output maps. It could be extended to permuations if
|
||||
|
|
|
@ -124,3 +124,30 @@ func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
|
|||
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
|
||||
// CHECK: tensor.expand_shape %[[OP]]
|
||||
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>
|
||||
|
||||
// -----
|
||||
|
||||
func @generic_op_index_semantics(%A: tensor<?x16xi64>, %B: tensor<16xi64>, %init: tensor<?x112x16xi64>) -> tensor<?x112x16xi64> {
|
||||
%0 = tensor.expand_shape %A [[0, 1], [2]]
|
||||
: tensor<?x16xi64> into tensor<?x112x16xi64>
|
||||
%2 = linalg.generic {indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
|
||||
iterator_types = ["parallel", "parallel", "parallel"]}
|
||||
ins(%0, %B : tensor<?x112x16xi64>, tensor<16xi64>)
|
||||
outs(%init : tensor<?x112x16xi64>) {
|
||||
^bb0(%arg1: i64, %arg2: i64, %arg3: i64): // no predecessors
|
||||
%index = linalg.index 0 : index
|
||||
%1 = arith.index_cast %index : index to i64
|
||||
%add = arith.addi %arg1, %1 : i64
|
||||
%s = arith.subi %add, %arg2 : i64
|
||||
linalg.yield %s : i64
|
||||
} -> tensor<?x112x16xi64>
|
||||
return %2 : tensor<?x112x16xi64>
|
||||
}
|
||||
// CHECK: func @generic_op_index_semantics
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x16xi64>
|
||||
// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]]
|
||||
// CHECK: %[[RESULT:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[RESHAPE]]
|
||||
// CHECK: return %[[RESULT]]
|
||||
|
|
Loading…
Reference in New Issue