forked from OSchip/llvm-project
[mlir][Linalg] Drop filter-based splitReduction
This transformation is available and tested via the transform dialect.

Differential Revision: https://reviews.llvm.org/D135767
This commit is contained in:
parent
bbe4441d33
commit
e0cea169f7
|
@ -1050,7 +1050,6 @@ using ControlSplitReductionFn =
|
||||||
void populateSplitReductionPattern(
|
void populateSplitReductionPattern(
|
||||||
RewritePatternSet &patterns,
|
RewritePatternSet &patterns,
|
||||||
const ControlSplitReductionFn &controlSplitReductionFn,
|
const ControlSplitReductionFn &controlSplitReductionFn,
|
||||||
const LinalgTransformationFilter &f = LinalgTransformationFilter(),
|
|
||||||
bool useAlloc = false);
|
bool useAlloc = false);
|
||||||
|
|
||||||
/// Apply transformation to split the single linalg op reduction into a parallel
|
/// Apply transformation to split the single linalg op reduction into a parallel
|
||||||
|
@ -1094,14 +1093,6 @@ void populateSplitReductionPattern(
|
||||||
/// linalg.yield %5 : f32
|
/// linalg.yield %5 : f32
|
||||||
/// } -> tensor<f32>
|
/// } -> tensor<f32>
|
||||||
/// ```
|
/// ```
|
||||||
FailureOr<LinalgOp>
|
|
||||||
splitReduction(PatternRewriter &b, LinalgOp op,
|
|
||||||
const ControlSplitReductionFn &controlSplitReductionFn,
|
|
||||||
const LinalgTransformationFilter &f, bool useAlloc = false);
|
|
||||||
|
|
||||||
/// Filterless version of the above.
|
|
||||||
/// Returns both the new linalg ops as well as the fillOp needed to initialize
|
|
||||||
/// the temporary expanded tensor with the proper neutral element.
|
|
||||||
struct SplitReductionResult {
|
struct SplitReductionResult {
|
||||||
Operation *initOrAlloc;
|
Operation *initOrAlloc;
|
||||||
FillOp fillOp;
|
FillOp fillOp;
|
||||||
|
|
|
@ -58,26 +58,6 @@ static Attribute getNeutralElement(Operation *op) {
|
||||||
return Attribute();
|
return Attribute();
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<LinalgOp> mlir::linalg::splitReduction(
|
|
||||||
PatternRewriter &b, LinalgOp op,
|
|
||||||
const ControlSplitReductionFn &controlSplitReductionFn,
|
|
||||||
const LinalgTransformationFilter &filter, bool useAlloc) {
|
|
||||||
if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
|
|
||||||
op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
|
|
||||||
!op.hasOnlyProjectedPermutations())
|
|
||||||
return b.notifyMatchFailure(op, "precondition not met");
|
|
||||||
|
|
||||||
FailureOr<SplitReductionResult> res =
|
|
||||||
splitReduction(b, op, controlSplitReductionFn, useAlloc);
|
|
||||||
if (failed(res))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
filter.replaceLinalgTransformationFilter(b, res->splitLinalgOp);
|
|
||||||
filter.replaceLinalgTransformationFilter(b, res->resultCombiningLinalgOp);
|
|
||||||
|
|
||||||
return res->splitLinalgOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
|
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
|
||||||
PatternRewriter &b, LinalgOp op,
|
PatternRewriter &b, LinalgOp op,
|
||||||
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
|
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
|
||||||
|
@ -481,30 +461,26 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
|
||||||
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
|
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
|
||||||
LinalgSplitReduction(MLIRContext *context,
|
LinalgSplitReduction(MLIRContext *context,
|
||||||
ControlSplitReductionFn controlSplitReductionFn,
|
ControlSplitReductionFn controlSplitReductionFn,
|
||||||
LinalgTransformationFilter f, bool useAlloc = false,
|
bool useAlloc = false, PatternBenefit benefit = 1)
|
||||||
PatternBenefit benefit = 1)
|
|
||||||
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
|
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
|
||||||
controlSplitReductionFn(std::move(controlSplitReductionFn)),
|
controlSplitReductionFn(std::move(controlSplitReductionFn)),
|
||||||
useAlloc(useAlloc), filter(std::move(f)) {}
|
useAlloc(useAlloc) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(LinalgOp op,
|
LogicalResult matchAndRewrite(LinalgOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
return splitReduction(rewriter, op, controlSplitReductionFn, filter,
|
return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
|
||||||
useAlloc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ControlSplitReductionFn controlSplitReductionFn;
|
ControlSplitReductionFn controlSplitReductionFn;
|
||||||
bool useAlloc;
|
bool useAlloc;
|
||||||
LinalgTransformationFilter filter;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void linalg::populateSplitReductionPattern(
|
void linalg::populateSplitReductionPattern(
|
||||||
RewritePatternSet &patterns,
|
RewritePatternSet &patterns,
|
||||||
const ControlSplitReductionFn &controlSplitReductionFn,
|
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
|
||||||
const LinalgTransformationFilter &f, bool useAlloc) {
|
|
||||||
patterns.add<LinalgSplitReduction>(patterns.getContext(),
|
patterns.add<LinalgSplitReduction>(patterns.getContext(),
|
||||||
controlSplitReductionFn, f, useAlloc);
|
controlSplitReductionFn, useAlloc);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,193 +0,0 @@
|
||||||
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction -split-input-file | FileCheck %s
|
|
||||||
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-split-reduction-inner-parallel -split-input-file | FileCheck %s --check-prefix=INNERPARALLELCHECK
|
|
||||||
|
|
||||||
func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
|
|
||||||
%0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>)
|
|
||||||
outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
|
|
||||||
return %0: tensor<16x32xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
|
|
||||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
|
|
||||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
|
|
||||||
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
|
||||||
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
|
||||||
// CHECK-LABEL: @matmul_split
|
|
||||||
// CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
|
|
||||||
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
|
|
||||||
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
|
|
||||||
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
|
|
||||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
|
|
||||||
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
|
||||||
// CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
|
|
||||||
// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x4x64xf32>, tensor<4x64x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
|
|
||||||
// CHECK: arith.mulf
|
|
||||||
// CHECK: arith.addf
|
|
||||||
// CHECK: linalg.yield
|
|
||||||
// CHECK: } -> tensor<16x32x4xf32>
|
|
||||||
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
|
|
||||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
|
|
||||||
// CHECK: arith.addf
|
|
||||||
// CHECK: linalg.yield %{{.*}} : f32
|
|
||||||
// CHECK: } -> tensor<16x32xf32>
|
|
||||||
// CHECK: return %[[R]] : tensor<16x32xf32>
|
|
||||||
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3, d1)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
|
||||||
// INNERPARALLELCHECK-LABEL: @matmul_split
|
|
||||||
// INNERPARALLELCHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
|
|
||||||
// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32>
|
|
||||||
// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32>
|
|
||||||
// INNERPARALLELCHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
|
|
||||||
// INNERPARALLELCHECK-SAME: , iterator_types = ["parallel", "parallel", "reduction", "parallel"]}
|
|
||||||
// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<16x64x4xf32>, tensor<64x4x32xf32>) outs(%[[F]] : tensor<16x32x4xf32>) {
|
|
||||||
// INNERPARALLELCHECK: arith.mulf
|
|
||||||
// INNERPARALLELCHECK: arith.addf
|
|
||||||
// INNERPARALLELCHECK: linalg.yield
|
|
||||||
// INNERPARALLELCHECK: } -> tensor<16x32x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]],
|
|
||||||
// INNERPARALLELCHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[G]] : tensor<16x32x4xf32>) outs(%{{.*}} : tensor<16x32xf32>) {
|
|
||||||
// INNERPARALLELCHECK: arith.addf
|
|
||||||
// INNERPARALLELCHECK: linalg.yield %{{.*}} : f32
|
|
||||||
// INNERPARALLELCHECK: } -> tensor<16x32xf32>
|
|
||||||
// INNERPARALLELCHECK: return %[[R]] : tensor<16x32xf32>
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: tensor<f32>) -> tensor<f32> {
|
|
||||||
%red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
|
||||||
affine_map<(d0) -> ()>,
|
|
||||||
affine_map<(d0) -> ()>],
|
|
||||||
iterator_types = ["reduction"]}
|
|
||||||
ins(%arg0, %arg1 : tensor<32xf32>, tensor<f32>)
|
|
||||||
outs(%out : tensor<f32>) {
|
|
||||||
^bb0(%arg7: f32, %arg8: f32, %arg9: f32):
|
|
||||||
%40 = arith.subf %arg7, %arg8 : f32
|
|
||||||
%41 = math.exp %40 : f32
|
|
||||||
%42 = arith.mulf %41, %arg9 : f32
|
|
||||||
linalg.yield %42 : f32
|
|
||||||
} -> tensor<f32>
|
|
||||||
return %red : tensor<f32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
|
||||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
|
|
||||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
|
|
||||||
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
|
|
||||||
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
|
|
||||||
//CHECK-LABEL: @generic_split_1d
|
|
||||||
// CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
|
|
||||||
// CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
|
|
||||||
// CHECK: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
|
|
||||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
|
|
||||||
// CHECK: %[[G:.*]] = linalg.generic
|
|
||||||
// CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
|
|
||||||
// CHECK: iterator_types = ["parallel", "reduction"]} ins(%[[I1]], %{{.*}} : tensor<4x8xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
|
|
||||||
// CHECK: arith.subf
|
|
||||||
// CHECK: math.exp
|
|
||||||
// CHECK: arith.mulf
|
|
||||||
// CHECK: linalg.yield
|
|
||||||
// CHECK: } -> tensor<4xf32>
|
|
||||||
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
|
|
||||||
// CHECK: arith.mulf
|
|
||||||
// CHECK: linalg.yield
|
|
||||||
// CHECK: } -> tensor<f32>
|
|
||||||
// CHECK: return %[[R]] : tensor<f32>
|
|
||||||
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
|
|
||||||
//INNERPARALLELCHECK-LABEL: @generic_split_1d
|
|
||||||
// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
|
|
||||||
// INNERPARALLELCHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic
|
|
||||||
// INNERPARALLELCHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]],
|
|
||||||
// INNERPARALLELCHECK: iterator_types = ["reduction", "parallel"]} ins(%[[I1]], %{{.*}} : tensor<8x4xf32>, tensor<f32>) outs(%[[F]] : tensor<4xf32>) {
|
|
||||||
// INNERPARALLELCHECK: arith.subf
|
|
||||||
// INNERPARALLELCHECK: math.exp
|
|
||||||
// INNERPARALLELCHECK: arith.mulf
|
|
||||||
// INNERPARALLELCHECK: linalg.yield
|
|
||||||
// INNERPARALLELCHECK: } -> tensor<4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["reduction"]} ins(%[[G]] : tensor<4xf32>) outs(%{{.*}} : tensor<f32>) {
|
|
||||||
// INNERPARALLELCHECK: arith.mulf
|
|
||||||
// INNERPARALLELCHECK: linalg.yield
|
|
||||||
// INNERPARALLELCHECK: } -> tensor<f32>
|
|
||||||
// INNERPARALLELCHECK: return %[[R]] : tensor<f32>
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>)
|
|
||||||
-> tensor<5x2xf32>
|
|
||||||
{
|
|
||||||
%0 = linalg.generic {
|
|
||||||
indexing_maps = [
|
|
||||||
affine_map<(d0, d1, d2) -> (d1, d0)>,
|
|
||||||
affine_map<(d0, d1, d2) -> (d2, d1)>,
|
|
||||||
affine_map<(d0, d1, d2) -> (d2, d0)>
|
|
||||||
],
|
|
||||||
iterator_types = ["parallel", "reduction", "parallel"]
|
|
||||||
} ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) {
|
|
||||||
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
|
|
||||||
%3 = arith.addf %arg0, %arg1 : f32
|
|
||||||
%4 = arith.maxf %3, %arg2 : f32
|
|
||||||
linalg.yield %4 : f32
|
|
||||||
} -> tensor<5x2xf32>
|
|
||||||
return %0 : tensor<5x2xf32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)>
|
|
||||||
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
|
|
||||||
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
|
|
||||||
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
|
||||||
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
|
||||||
// CHECK-LABEL: func @generic_split_3d
|
|
||||||
// CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
|
|
||||||
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
|
|
||||||
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
|
|
||||||
// CHECK: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
|
|
||||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
|
|
||||||
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
|
|
||||||
// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
|
|
||||||
// CHECK: arith.addf
|
|
||||||
// CHECK: arith.maxf
|
|
||||||
// CHECK: linalg.yield
|
|
||||||
// CHECK: } -> tensor<5x2x4xf32>
|
|
||||||
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
|
|
||||||
// CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
|
|
||||||
// CHECK: arith.maxf
|
|
||||||
// CHECK: linalg.yield
|
|
||||||
// CHECK: } -> tensor<5x2xf32>
|
|
||||||
// CHECK: return %[[R]] : tensor<5x2xf32>
|
|
||||||
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
|
||||||
// INNERPARALLELCHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
|
||||||
// INNERPARALLELCHECK-LABEL: func @generic_split_3d
|
|
||||||
// INNERPARALLELCHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
|
|
||||||
// INNERPARALLELCHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
|
|
||||||
// INNERPARALLELCHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
|
|
||||||
// INNERPARALLELCHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) {
|
|
||||||
// INNERPARALLELCHECK: arith.addf
|
|
||||||
// INNERPARALLELCHECK: arith.maxf
|
|
||||||
// INNERPARALLELCHECK: linalg.yield
|
|
||||||
// INNERPARALLELCHECK: } -> tensor<5x2x4xf32>
|
|
||||||
// INNERPARALLELCHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]}
|
|
||||||
// INNERPARALLELCHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) {
|
|
||||||
// INNERPARALLELCHECK: arith.maxf
|
|
||||||
// INNERPARALLELCHECK: linalg.yield
|
|
||||||
// INNERPARALLELCHECK: } -> tensor<5x2xf32>
|
|
||||||
// INNERPARALLELCHECK: return %[[R]] : tensor<5x2xf32>
|
|
|
@ -84,14 +84,6 @@ struct TestLinalgTransforms
|
||||||
llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
|
llvm::cl::desc("Test rewrite of subtensor(tensor.pad) into "
|
||||||
"tensor.pad(subtensor)"),
|
"tensor.pad(subtensor)"),
|
||||||
llvm::cl::init(false)};
|
llvm::cl::init(false)};
|
||||||
Option<bool> testSplitReduction{
|
|
||||||
*this, "test-split-reduction",
|
|
||||||
llvm::cl::desc("Test split reduction transformation"),
|
|
||||||
llvm::cl::init(false)};
|
|
||||||
Option<bool> testSplitReductionInnerParallel{
|
|
||||||
*this, "test-split-reduction-inner-parallel",
|
|
||||||
llvm::cl::desc("Test split reduction with inner parallel transformation"),
|
|
||||||
llvm::cl::init(false)};
|
|
||||||
ListOption<int64_t> peeledLoops{
|
ListOption<int64_t> peeledLoops{
|
||||||
*this, "peeled-loops",
|
*this, "peeled-loops",
|
||||||
llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
|
llvm::cl::desc("Loops to be peeled when test-tile-pattern")};
|
||||||
|
@ -176,34 +168,6 @@ static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
|
||||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void applySplitReduction(func::FuncOp funcOp) {
|
|
||||||
RewritePatternSet patterns(funcOp.getContext());
|
|
||||||
linalg::populateSplitReductionPattern(
|
|
||||||
patterns,
|
|
||||||
[](LinalgOp op) {
|
|
||||||
unsigned insertDimIndex = op.getNumLoops() - 1;
|
|
||||||
return SplitReductionOptions{4, insertDimIndex, false};
|
|
||||||
},
|
|
||||||
LinalgTransformationFilter(
|
|
||||||
ArrayRef<StringAttr>{},
|
|
||||||
StringAttr::get(funcOp.getContext(), "SPLIT")));
|
|
||||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void applySplitReductionInnerParallel(func::FuncOp funcOp) {
|
|
||||||
RewritePatternSet patterns(funcOp.getContext());
|
|
||||||
linalg::populateSplitReductionPattern(
|
|
||||||
patterns,
|
|
||||||
[](LinalgOp op) {
|
|
||||||
unsigned insertDimIndex = op.getNumLoops() - 1;
|
|
||||||
return SplitReductionOptions{4, insertDimIndex, true};
|
|
||||||
},
|
|
||||||
LinalgTransformationFilter(
|
|
||||||
ArrayRef<StringAttr>{},
|
|
||||||
StringAttr::get(funcOp.getContext(), "SPLIT")));
|
|
||||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
|
||||||
}
|
|
||||||
|
|
||||||
static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
|
static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
|
||||||
RewritePatternSet patterns(funcOp.getContext());
|
RewritePatternSet patterns(funcOp.getContext());
|
||||||
populateBubbleUpExtractSliceOpPatterns(patterns);
|
populateBubbleUpExtractSliceOpPatterns(patterns);
|
||||||
|
@ -237,10 +201,6 @@ void TestLinalgTransforms::runOnOperation() {
|
||||||
return applyGeneralizePadTensorPatterns(getOperation());
|
return applyGeneralizePadTensorPatterns(getOperation());
|
||||||
if (testSwapSubTensorPadTensor)
|
if (testSwapSubTensorPadTensor)
|
||||||
return applyExtractSliceOfPadTensorSwapPattern(getOperation());
|
return applyExtractSliceOfPadTensorSwapPattern(getOperation());
|
||||||
if (testSplitReduction)
|
|
||||||
return applySplitReduction(getOperation());
|
|
||||||
if (testSplitReductionInnerParallel)
|
|
||||||
return applySplitReductionInnerParallel(getOperation());
|
|
||||||
if (testBubbleUpExtractSliceOpPattern)
|
if (testBubbleUpExtractSliceOpPattern)
|
||||||
return applyBubbleUpExtractSliceOpPattern(getOperation());
|
return applyBubbleUpExtractSliceOpPattern(getOperation());
|
||||||
if (testSwapExtractSliceWithFill)
|
if (testSwapExtractSliceWithFill)
|
||||||
|
|
Loading…
Reference in New Issue