[mlir][Linalg] Uniformize SplitReduction transforms and add option to use Bufferization::AllocTensor
This revision merges the two split_reduction transforms and adds extra control via attributes. SplitReduction is known to require a concrete additional buffer to store temporary information. Add an option to introduce a `bufferization.alloc_tensor` instead of a `linalg.init_tensor`; this behaves better with subset-based tiling and bufferization.

Differential Revision: https://reviews.llvm.org/D128722
parent 7c4b90a98d
commit 178f9bd63c

@@ -164,8 +164,24 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
     reduction into a parallel and reduction dimension.
     A new `linalg.generic` op is created to perform the rest of the reduction.
 
-    Example:
-
+    The transformation supports different configuration attributes:
+      - split_factor: the factor by which to split (i.e. the size of the
+        remaining reduction after splitting).
+      - insert_split_dimension: the dimension in the temporary tensor into
+        which the new parallel dimension is inserted.
+      - use_scaling_algorithm: whether to use a scaling-based formulation that
+        does not create an ExpandShapeOp (default: do not use scaling)
+      - use_alloc: whether to use an alloc op to allocate the temporary
+        tensor (default: do not use alloc op)
+
+    This op returns 4 handles to:
+      - the init op (or tensor_alloc op if use_alloc = true),
+      - the fill op used to initialize the neutral element,
+      - the split op and
+      - the result-combining op.
+
+    Example (default: use_scaling_algorithm = false, use_alloc = false):
+    ====================================================================
     ```
       %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                             affine_map<(d0) -> ()>],

@@ -178,7 +194,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
       } -> tensor<f32>
     ```
 
-    To:
+    is split into:
 
     ```
       %cst = arith.constant 0.000000e+00 : f32

@@ -203,34 +219,8 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
       } -> tensor<f32>
     ```
 
-    This op returns handles to the fill op used to initialize the neutral
-    element, the split op and the result-combining op.
-  }];
-
-  let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64Attr, "{}">:$split_factor,
-                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
-  let results = (outs PDL_Operation:$fill_op,
-                  PDL_Operation:$split_linalg_op,
-                  PDL_Operation:$combining_linalg_op);
-
-  let assemblyFormat = "$target attr-dict";
-
-  let extraClassDeclaration = [{
-    ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
-        ::mlir::linalg::LinalgOp target, TransformState &state);
-  }];
-}
-
-def SplitReductionByScalingOp :
-  Op<Transform_Dialect, "structured.split_reduction_by_scaling",
-    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
-     TransformEachOpTrait, TransformOpInterface]> {
-  let description = [{
-    Indicates that the given `target` op should be transformed with the
-    `splitReductionByScaling` transformation and split factor provided as
-    attribute.
-
+    Example (use_scaling_algorithm = true, use_alloc = true):
+    =========================================================
     Instead of introducing an ExpandShapeOp, this scaling-based implementation
     rewrites a reduction dimension `k` into `k * split_factor + kk`.
     The dimension `kk` is added as an extra parallel dimension to the

@@ -287,12 +277,13 @@ def SplitReductionByScalingOp :
       return %4 : tensor<16x32xf32>
     ```
-
   }];
 
   let arguments = (ins PDL_Operation:$target,
                    DefaultValuedAttr<I64Attr, "{}">:$split_factor,
-                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
+                   DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
+                   UnitAttr:$use_scaling_algorithm,
+                   UnitAttr:$use_alloc);
   let results = (outs PDL_Operation:$fill_op,
                   PDL_Operation:$split_linalg_op,
                   PDL_Operation:$combining_linalg_op);

@@ -1474,7 +1474,8 @@ using ControlSplitReductionFn =
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,
     const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &f = LinalgTransformationFilter());
+    const LinalgTransformationFilter &f = LinalgTransformationFilter(),
+    bool useAlloc = false);
 
 /// Apply transformation to split the single linalg op reduction into a parallel
 /// and reduction dimension. Then create a new linalg.generic op doing the rest

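For illustration, a minimal sketch of how a pass might thread the new `useAlloc` flag through `populateSplitReductionPattern`. The enclosing pass body, `funcOp`, and the fixed control values of 4 and 2 are assumptions for this sketch, not part of the patch:

```
// Hypothetical caller: populate the split-reduction pattern with the new
// useAlloc flag so the intermediate tensor is materialized with
// bufferization.alloc_tensor instead of linalg.init_tensor.
RewritePatternSet patterns(funcOp.getContext());
linalg::ControlSplitReductionFn control = [](linalg::LinalgOp op) {
  // Split the reduction by a factor of 4; insert the new parallel
  // dimension at position 2 of the intermediate tensor.
  return std::pair<int64_t, unsigned>(4, 2);
};
linalg::populateSplitReductionPattern(patterns, control,
                                      linalg::LinalgTransformationFilter(),
                                      /*useAlloc=*/true);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();
```
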
@@ -1518,19 +1519,21 @@ void populateSplitReductionPattern(
 FailureOr<LinalgOp>
 splitReduction(PatternRewriter &b, LinalgOp op,
                const ControlSplitReductionFn &controlSplitReductionFn,
-               const LinalgTransformationFilter &f);
+               const LinalgTransformationFilter &f, bool useAlloc = false);
 
 /// Filterless version of the above.
 /// Returns both the new linalg ops as well as the fillOp needed to initialize
 /// the temporary expanded tensor with the proper neutral element.
 struct SplitReductionResult {
+  Operation *initOrAlloc;
   FillOp fillOp;
   LinalgOp splitLinalgOp;
   LinalgOp resultCombiningLinalgOp;
 };
 FailureOr<SplitReductionResult>
 splitReduction(PatternRewriter &b, LinalgOp op,
-               const ControlSplitReductionFn &controlSplitReductionFn);
+               const ControlSplitReductionFn &controlSplitReductionFn,
+               bool useAlloc = false);
 
 /// Scaling-based implementation of the split reduction transformation.
 /// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension

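A hedged sketch of consuming `SplitReductionResult` from the filterless entry point; `rewriter`, `linalgOp`, and `control` are assumed to already be in scope (e.g. inside a pattern or transform implementation):

```
// Hypothetical caller: run the filterless splitReduction and unpack the
// result struct declared above.
FailureOr<linalg::SplitReductionResult> res =
    linalg::splitReduction(rewriter, linalgOp, control, /*useAlloc=*/true);
if (failed(res))
  return failure();
// With useAlloc = true, initOrAlloc is the bufferization.alloc_tensor op;
// with the default it is the linalg.init_tensor op.
Operation *initOrAlloc = res->initOrAlloc;
linalg::FillOp fill = res->fillOp;            // writes the neutral element
linalg::LinalgOp split = res->splitLinalgOp;  // partial, parallel reduction
linalg::LinalgOp combine = res->resultCombiningLinalgOp;
```
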
@@ -1580,7 +1583,8 @@ splitReduction(PatternRewriter &b, LinalgOp op,
 /// ```
 FailureOr<SplitReductionResult>
 splitReductionByScaling(PatternRewriter &b, LinalgOp op,
-                        const ControlSplitReductionFn &controlSplitReductionFn);
+                        const ControlSplitReductionFn &controlSplitReductionFn,
+                        bool useAlloc = false);
 
 } // namespace linalg
 } // namespace mlir

@@ -413,29 +413,9 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
   SimpleRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
   FailureOr<SplitReductionResult> splitResult =
-      splitReduction(rewriter, target, splitFn);
-  if (failed(splitResult))
-    return getOperation()->emitError("failed to apply");
-  return SmallVector<Operation *>{splitResult->fillOp,
-                                  splitResult->splitLinalgOp,
-                                  splitResult->resultCombiningLinalgOp};
-}
-
-//===----------------------------------------------------------------------===//
-// SplitReductionByScalingOp
-//===----------------------------------------------------------------------===//
-
-FailureOr<SmallVector<Operation *>>
-transform::SplitReductionByScalingOp::applyToOne(LinalgOp target,
-                                                 TransformState &state) {
-  ControlSplitReductionFn splitFn = [&](LinalgOp) {
-    return std::pair<int64_t, unsigned>(getSplitFactor(),
-                                        getInsertSplitDimension());
-  };
-  SimpleRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  FailureOr<SplitReductionResult> splitResult =
-      splitReductionByScaling(rewriter, target, splitFn);
+      (getUseScalingAlgorithm())
+          ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
+          : splitReduction(rewriter, target, splitFn, getUseAlloc());
   if (failed(splitResult))
     return getOperation()->emitError("failed to apply");
   return SmallVector<Operation *>{splitResult->fillOp,

@@ -15,6 +15,7 @@
 
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"

@@ -60,14 +61,14 @@ static Attribute getNeutralElement(Operation *op) {
 FailureOr<LinalgOp> mlir::linalg::splitReduction(
     PatternRewriter &b, LinalgOp op,
     const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &filter) {
+    const LinalgTransformationFilter &filter, bool useAlloc) {
   if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
       op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
       !op.hasOnlyProjectedPermutations())
     return b.notifyMatchFailure(op, "precondition not met");
 
   FailureOr<SplitReductionResult> res =
-      splitReduction(b, op, controlSplitReductionFn);
+      splitReduction(b, op, controlSplitReductionFn, useAlloc);
   if (failed(res))
     return failure();
 

@@ -79,7 +80,7 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
 
 FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     PatternRewriter &b, LinalgOp op,
-    const ControlSplitReductionFn &controlSplitReductionFn) {
+    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(op);
 

@@ -171,11 +172,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
     outputExpr.push_back(
         b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
   }
-  Value initTensor = b.create<linalg::InitTensorOp>(
-      loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+  Value initOrAllocTensor;
+  if (useAlloc) {
+    initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
+        loc,
+        RankedTensorType::get(newOutputShape,
+                              op.getRegionOutputArgs()[0].getType()),
+        ValueRange{});
+  } else {
+    initOrAllocTensor = b.create<linalg::InitTensorOp>(
+        loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
+  }
   Value constantOp = b.create<arith::ConstantOp>(loc, identity);
   Value identityTensor =
-      b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
+      b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
           .getResult(0);
 
   newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,

@@ -189,7 +199,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   // Create the new op matching the original op with an extra parallel
   // dimension.
   GenericOp genericOp = b.create<GenericOp>(
-      loc, TypeRange({initTensor.getType()}), newInputs,
+      loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
       ValueRange({identityTensor}), newMaps, newIteratorTypes);
   b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
                        genericOp.region().begin());

@@ -223,9 +233,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   });
   b.replaceOp(op, reduction.getResults());
 
-  return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
-                              cast<LinalgOp>(genericOp.getOperation()),
-                              reduction};
+  return SplitReductionResult{
+      initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
+      cast<LinalgOp>(genericOp.getOperation()), reduction};
 }
 
 /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)

@@ -260,7 +270,7 @@ static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
 /// Core rewrite implementation.
 FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
     PatternRewriter &b, LinalgOp op,
-    const ControlSplitReductionFn &controlSplitReductionFn) {
+    const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
   OpBuilder::InsertionGuard guard(b);
   b.setInsertionPoint(op);
 

@@ -297,7 +307,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
     return b.notifyMatchFailure(op, "unknown reduction neutral");
 
   // TODO: relax this when multi-reduction support is available.
-  if (op.getNumOutputs() != (int)neutralElements.size())
+  if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
     return b.notifyMatchFailure(op, "expect one reduction per output");
 
   // Rewrite part.

@@ -318,6 +328,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // TODO: generalize when multi-reduction support is available.
   SmallVector<Value> newOutputs;
   newOutputs.reserve(op.getNumOutputs());
+  SmallVector<Operation *> initOrAllocTensorOps;
   SmallVector<linalg::FillOp> fillOps;
   fillOps.reserve(op.getNumOutputs());
   for (auto it : llvm::zip(op.outputs(), neutralElements)) {

@@ -327,12 +338,19 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
         reductionDimSize / splitFactor, insertSplitDimension);
     SmallVector<Value> dims =
         tensor::createDynamicDimValues(b, loc, rankedTensor);
-    Value initTensor = b.create<linalg::InitTensorOp>(
-        loc, dims, newT.getShape(), t.getElementType());
+    Value initOrAllocTensor;
+    if (useAlloc) {
+      initOrAllocTensor =
+          b.create<bufferization::AllocTensorOp>(loc, newT, dims);
+    } else {
+      initOrAllocTensor = b.create<linalg::InitTensorOp>(
+          loc, dims, newT.getShape(), t.getElementType());
+    }
     Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
     fillOps.push_back(
-        b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor));
+        b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
     newOutputs.push_back(fillOps.back().getResult(0));
+    initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
   }
 
   // Step 2. Reindex / expand indexing maps.

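Note how the two hunks above differ in how dynamic sizes reach `AllocTensorOp`: the non-scaling path builds a statically shaped tensor and passes an empty `ValueRange`, while this scaling path forwards the `dims` collected by `tensor::createDynamicDimValues`. A small sketch with illustrative shapes; `b`, `loc`, and the `dim0` value are assumptions for this sketch:

```
// Static result shape: no dynamic-size operands are needed.
Value staticTmp = b.create<bufferization::AllocTensorOp>(
    loc, RankedTensorType::get({16, 32, 4}, b.getF32Type()), ValueRange{});
// Dynamic result shape: one SSA value per '?' dimension, in order.
Value dynTmp = b.create<bufferization::AllocTensorOp>(
    loc,
    RankedTensorType::get({ShapedType::kDynamicSize, 32}, b.getF32Type()),
    ValueRange{dim0});
```
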
@@ -423,7 +441,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // TODO: extend when multi-reduction support is available.
   assert(fillOps.size() == results.size() && results.size() == 1);
   b.replaceOp(op, results.front()->getResults());
-  return SplitReductionResult{fillOps.front(),
+  return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
                               cast<LinalgOp>(genericOp.getOperation()),
                               results.front()};
 }

@@ -434,18 +452,21 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
   /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
   LinalgSplitReduction(MLIRContext *context,
                        ControlSplitReductionFn controlSplitReductionFn,
-                       LinalgTransformationFilter f, PatternBenefit benefit = 1)
+                       LinalgTransformationFilter f, bool useAlloc = false,
+                       PatternBenefit benefit = 1)
       : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlSplitReductionFn(std::move(controlSplitReductionFn)),
-        filter(std::move(f)) {}
+        useAlloc(useAlloc), filter(std::move(f)) {}
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    return splitReduction(rewriter, op, controlSplitReductionFn, filter);
+    return splitReduction(rewriter, op, controlSplitReductionFn, filter,
+                          useAlloc);
   }
 
 private:
   ControlSplitReductionFn controlSplitReductionFn;
+  bool useAlloc;
   LinalgTransformationFilter filter;
 };
 

@@ -454,7 +475,7 @@ private:
 void linalg::populateSplitReductionPattern(
     RewritePatternSet &patterns,
     const ControlSplitReductionFn &controlSplitReductionFn,
-    const LinalgTransformationFilter &f) {
+    const LinalgTransformationFilter &f, bool useAlloc) {
   patterns.add<LinalgSplitReduction>(patterns.getContext(),
-                                     controlSplitReductionFn, f);
+                                     controlSplitReductionFn, f, useAlloc);
 }

@@ -3,6 +3,7 @@
 // CHECK-LABEL: func.func @matmul_split
 func.func @matmul_split(%A : tensor<?x256xf32>, %B: tensor<256x32xf32>, %C: tensor<?x32xf32>) -> tensor<?x32xf32> {
 
+  // CHECK: bufferization.alloc_tensor({{.*}}) : tensor<?x32x64xf32>
   // CHECK: linalg.generic
   // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
   // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<?x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)

@@ -30,6 +31,7 @@ transform.with_pdl_patterns {
   transform.sequence %arg0 {
   ^bb1(%arg1: !pdl.operation):
     %0 = pdl_match @pdl_target in %arg1
-    %1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2}
+    %1:3 = transform.structured.split_reduction %0
+      { split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc}
   }
 }

@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"

@@ -41,6 +42,7 @@ struct TestLinalgTransforms
   void getDependentDialects(DialectRegistry &registry) const override {
     // clang-format off
     registry.insert<AffineDialect,
+                    bufferization::BufferizationDialect,
                     memref::MemRefDialect,
                     scf::SCFDialect,
                     linalg::LinalgDialect,