[mlir][linalg] Support dropping unit dimensions for init tensors

Init tensor operands also have indexing maps and generally follow
the same constraints we expect for non-init-tensor operands.
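
For example, for the test added below, the reduction's init operand starts out as

  outs(%fill : tensor<1xf32>)      // indexing map: affine_map<(d0, d1) -> (d0)>

and, once the unit dimensions are dropped, is expected to become (sketched from the
CHECK lines below, not verbatim pass output; SSA names here are placeholders)

  outs(%reshaped : tensor<f32>)    // indexing map: affine_map<(d0) -> ()>

with the iterator types shrinking from ["parallel", "reduction"] to ["reduction"].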

Differential Revision: https://reviews.llvm.org/D99115
Author: Lei Zhang
Date:   2021-03-24 17:51:44 -04:00
Commit: c241e1c2f5
Parent: 7f28d27cb6
2 changed files with 36 additions and 2 deletions


@@ -294,8 +294,8 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: support init_tensors and reductions.
-    if (!op.hasTensorSemantics() || op.getNumInitTensors() != 0)
+    // TODO: support reductions.
+    if (!op.hasTensorSemantics())
       return failure();
     MLIRContext *context = rewriter.getContext();
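
With the guard relaxed, ops whose outs operands come from linalg.init_tensor / linalg.fill
are rewritten as well. The new test below exercises this; it is presumably driven by the
usual RUN header of the unit-dim folding test file (the exact flag spelling is an assumption,
not part of this diff):

// RUN: mlir-opt %s -split-input-file -linalg-fold-unit-extent-dims | FileCheck %s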


@@ -354,3 +354,37 @@ func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
// CHECK-LABEL: func @fold_unit_dim_tensor_reshape_op
// CHECK: %[[RESULT:.+]] = linalg.generic
// CHECK: return %[[RESULT]]
// -----
func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> {
  %cst = constant 0.0 : f32
  %init = linalg.init_tensor [1] : tensor<1xf32>
  %fill = linalg.fill(%init, %cst) : tensor<1xf32>, f32 -> tensor<1xf32>
  %add = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
      iterator_types = ["parallel", "reduction"]}
      ins(%input : tensor<1x1000xf32>) outs(%fill : tensor<1xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %1823 = addf %arg1, %arg2 : f32
    linalg.yield %1823 : f32
  } -> tensor<1xf32>
  return %add : tensor<1xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()>
// CHECK: func @fold_unit_dim_for_init_tensor
// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [#[[MAP0]]] : tensor<1x1000xf32> into tensor<1000xf32>
// CHECK: %[[INIT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] : tensor<1xf32> into tensor<f32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["reduction"]
// CHECK-SAME: ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
// CHECK-SAME: outs(%[[INIT_RESHAPE]] : tensor<f32>)
// CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
// CHECK: return %[[GENERIC_RESHAPE]] : tensor<1xf32>
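
Read together, the CHECK lines describe output along these lines (a hand-assembled sketch
from the patterns above, not verbatim pass output; SSA names are made up):

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0) -> (d0)>
#map2 = affine_map<(d0) -> ()>
func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> {
  %cst = constant 0.0 : f32
  %init = linalg.init_tensor [1] : tensor<1xf32>
  %fill = linalg.fill(%init, %cst) : tensor<1xf32>, f32 -> tensor<1xf32>
  // Unit dims are dropped by reshaping both the input and the init tensor.
  %0 = linalg.tensor_reshape %input [#map0] : tensor<1x1000xf32> into tensor<1000xf32>
  %1 = linalg.tensor_reshape %fill [] : tensor<1xf32> into tensor<f32>
  // The generic now iterates over a single reduction dimension.
  %2 = linalg.generic {
      indexing_maps = [#map1, #map2], iterator_types = ["reduction"]}
      ins(%0 : tensor<1000xf32>) outs(%1 : tensor<f32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %3 = addf %arg1, %arg2 : f32
    linalg.yield %3 : f32
  } -> tensor<f32>
  // Reshape back to the original 1-D result type.
  %4 = linalg.tensor_reshape %2 [] : tensor<f32> into tensor<1xf32>
  return %4 : tensor<1xf32>
}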