[mlir][Linalg] Uniformize SplitReduction transforms and add option to use Bufferization::AllocTensor

This revision merges the two split_reduction transforms and adds extra control through attributes.

SplitReduction is known to require a concrete additional buffer to store temporary information.
Add an option to introduce a `bufferization.alloc_tensor` instead of `linalg.init_tensor`.
This behaves better with subset-based tiling and bufferization.

Differential Revision: https://reviews.llvm.org/D128722
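
For illustration, a minimal sketch (hypothetical static shapes, not taken from the patch) of the two ways the temporary can now be materialized:

```mlir
// Default: the temporary is a linalg.init_tensor, whose eventual buffer
// is left for the bufferization passes to decide.
%init = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32>

// With use_alloc: a bufferization.alloc_tensor is emitted instead; it
// bufferizes to a fresh allocation, which is intended to compose better
// with subset-based tiling.
%alloc = bufferization.alloc_tensor() : tensor<16x32x4xf32>
```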
Author: Nicolas Vasilache
Date: 2022-06-28 05:17:32 -07:00
Commit: 178f9bd63c (parent: 7c4b90a98d)
6 changed files with 82 additions and 82 deletions

@@ -164,8 +164,24 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
reduction into a parallel and reduction dimension.
A new `linalg.generic` op is created to perform the rest of the reduction.
Example:
The transformation supports different configuration attributes:
- split_factor: the factor by which to split (i.e. the size of the
remaining reduction after splitting).
- insert_split_dimension: the dimension in the temporary tensor into
which the new parallel dimension is inserted.
- use_scaling_algorithm: whether to use a scaling based formulation that
does not create an ExpandShapeOp (default: do not use scaling)
- use_alloc: whether to use an alloc op to allocate the temporary
tensor (default: do not use alloc op)
This op returns 4 handles to:
- the init op (or alloc_tensor op if use_alloc = true),
- the fill op used to initialize the neutral element,
- the split op and
- the result-combining op.
Example (default: use_scaling_algorithm = false, use_alloc = false):
====================================================================
```
%r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> ()>],
@@ -178,7 +194,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
} -> tensor<f32>
```
To:
is split into:
```
%cst = arith.constant 0.000000e+00 : f32
@@ -203,34 +219,8 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
} -> tensor<f32>
```
This op returns handles to the fill op used to initialize the neutral
element, the split op and the result-combining op.
}];
let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64Attr, "{}">:$split_factor,
DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
let results = (outs PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
let assemblyFormat = "$target attr-dict";
let extraClassDeclaration = [{
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
::mlir::linalg::LinalgOp target, TransformState &state);
}];
}
def SplitReductionByScalingOp :
Op<Transform_Dialect, "structured.split_reduction_by_scaling",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
let description = [{
Indicates that the given `target` op should be transformed with the
`splitReductionByScaling` transformation and split factor provided as
attribute.
Example (use_scaling_algorithm = true, use_alloc = true):
=========================================================
Instead of introducing an ExpandShapeOp, this scaling-based implementation
rewrites a reduction dimension `k` into `k * split_factor + kk`.
The dimension `kk` is added as an extra parallel dimension to the
@@ -287,12 +277,13 @@ def SplitReductionByScalingOp :
return %4 : tensor<16x32xf32>
```
}];
let arguments = (ins PDL_Operation:$target,
DefaultValuedAttr<I64Attr, "{}">:$split_factor,
DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension);
DefaultValuedAttr<I64Attr, "{}">:$insert_split_dimension,
UnitAttr:$use_scaling_algorithm,
UnitAttr:$use_alloc);
let results = (outs PDL_Operation:$fill_op,
PDL_Operation:$split_linalg_op,
PDL_Operation:$combining_linalg_op);
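
Both code paths are now reachable from the single merged op. A hypothetical transform-dialect snippet (the handle `%0` is assumed to match a suitable `linalg.generic`; the two invocations are shown as alternatives, not meant to run back to back) might look like:

```mlir
// Default: expand_shape-based split with a linalg.init_tensor temporary.
// The three results are handles to the fill, split, and combining ops.
%1:3 = transform.structured.split_reduction %0
    { split_factor = 4, insert_split_dimension = 2 }

// Scaling-based split that allocates the temporary explicitly.
%2:3 = transform.structured.split_reduction %0
    { split_factor = 4, insert_split_dimension = 2,
      use_scaling_algorithm, use_alloc }
```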

@@ -1474,7 +1474,8 @@ using ControlSplitReductionFn =
void populateSplitReductionPattern(
RewritePatternSet &patterns,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &f = LinalgTransformationFilter());
const LinalgTransformationFilter &f = LinalgTransformationFilter(),
bool useAlloc = false);
/// Apply transformation to split the single linalg op reduction into a parallel
/// and reduction dimension. Then create a new linalg.generic op doing the rest
@@ -1518,19 +1519,21 @@ void populateSplitReductionPattern(
FailureOr<LinalgOp>
splitReduction(PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &f);
const LinalgTransformationFilter &f, bool useAlloc = false);
/// Filterless version of the above.
/// Returns both the new linalg ops as well as the fillOp needed to initialize
/// the temporary expanded tensor with the proper neutral element.
struct SplitReductionResult {
Operation *initOrAlloc;
FillOp fillOp;
LinalgOp splitLinalgOp;
LinalgOp resultCombiningLinalgOp;
};
FailureOr<SplitReductionResult>
splitReduction(PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn);
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);
/// Scaling-based implementation of the split reduction transformation.
/// Instead of introducing an ExpandShapeOp, this rewrites a reduction dimension
@@ -1580,7 +1583,8 @@ splitReduction(PatternRewriter &b, LinalgOp op,
/// ```
FailureOr<SplitReductionResult>
splitReductionByScaling(PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn);
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);
} // namespace linalg
} // namespace mlir
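
As a sketch of the index rewrite the scaling path performs (illustrative maps assuming `split_factor = 4` and a matmul-like operand; not lifted from the patch):

```mlir
// Before the split: k is the single reduction dimension.
#map_before = affine_map<(i, j, k) -> (i, k)>
// After the split: the access becomes k * 4 + kk, where kk is a new
// parallel dimension of size 4, so no tensor.expand_shape is required.
#map_after = affine_map<(i, j, k, kk) -> (i, k * 4 + kk)>
```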

@@ -413,29 +413,9 @@ transform::SplitReductionOp::applyToOne(LinalgOp target,
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<SplitReductionResult> splitResult =
splitReduction(rewriter, target, splitFn);
if (failed(splitResult))
return getOperation()->emitError("failed to apply");
return SmallVector<Operation *>{splitResult->fillOp,
splitResult->splitLinalgOp,
splitResult->resultCombiningLinalgOp};
}
//===----------------------------------------------------------------------===//
// SplitReductionByScalingOp
//===----------------------------------------------------------------------===//
FailureOr<SmallVector<Operation *>>
transform::SplitReductionByScalingOp::applyToOne(LinalgOp target,
TransformState &state) {
ControlSplitReductionFn splitFn = [&](LinalgOp) {
return std::pair<int64_t, unsigned>(getSplitFactor(),
getInsertSplitDimension());
};
SimpleRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
FailureOr<SplitReductionResult> splitResult =
splitReductionByScaling(rewriter, target, splitFn);
(getUseScalingAlgorithm())
? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
: splitReduction(rewriter, target, splitFn, getUseAlloc());
if (failed(splitResult))
return getOperation()->emitError("failed to apply");
return SmallVector<Operation *>{splitResult->fillOp,

@@ -15,6 +15,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -60,14 +61,14 @@ static Attribute getNeutralElement(Operation *op) {
FailureOr<LinalgOp> mlir::linalg::splitReduction(
PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &filter) {
const LinalgTransformationFilter &filter, bool useAlloc) {
if (failed(filter.checkAndNotify(b, op)) || !op.hasTensorSemantics() ||
op.getNumReductionLoops() != 1 || op.getNumOutputs() != 1 ||
!op.hasOnlyProjectedPermutations())
return b.notifyMatchFailure(op, "precondition not met");
FailureOr<SplitReductionResult> res =
splitReduction(b, op, controlSplitReductionFn);
splitReduction(b, op, controlSplitReductionFn, useAlloc);
if (failed(res))
return failure();
@@ -79,7 +80,7 @@ FailureOr<LinalgOp> mlir::linalg::splitReduction(
FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn) {
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(op);
@@ -171,11 +172,20 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
outputExpr.push_back(
b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
}
Value initTensor = b.create<linalg::InitTensorOp>(
loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
Value initOrAllocTensor;
if (useAlloc) {
initOrAllocTensor = b.create<bufferization::AllocTensorOp>(
loc,
RankedTensorType::get(newOutputShape,
op.getRegionOutputArgs()[0].getType()),
ValueRange{});
} else {
initOrAllocTensor = b.create<linalg::InitTensorOp>(
loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
}
Value constantOp = b.create<arith::ConstantOp>(loc, identity);
Value identityTensor =
b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor)
b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor)
.getResult(0);
newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
@@ -189,7 +199,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
// Create the new op matching the original op with an extra parallel
// dimension.
GenericOp genericOp = b.create<GenericOp>(
loc, TypeRange({initTensor.getType()}), newInputs,
loc, TypeRange({initOrAllocTensor.getType()}), newInputs,
ValueRange({identityTensor}), newMaps, newIteratorTypes);
b.inlineRegionBefore(op->getRegion(0), genericOp.region(),
genericOp.region().begin());
@@ -223,9 +233,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
});
b.replaceOp(op, reduction.getResults());
return SplitReductionResult{identityTensor.getDefiningOp<FillOp>(),
cast<LinalgOp>(genericOp.getOperation()),
reduction};
return SplitReductionResult{
initOrAllocTensor.getDefiningOp(), identityTensor.getDefiningOp<FillOp>(),
cast<LinalgOp>(genericOp.getOperation()), reduction};
}
/// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
@@ -260,7 +270,7 @@ static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
/// Core rewrite implementation.
FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
PatternRewriter &b, LinalgOp op,
const ControlSplitReductionFn &controlSplitReductionFn) {
const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(op);
@@ -297,7 +307,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
return b.notifyMatchFailure(op, "unknown reduction neutral");
// TODO: relax this when multi-reduction support is available.
if (op.getNumOutputs() != (int)neutralElements.size())
if (op.getNumOutputs() != static_cast<int64_t>(neutralElements.size()))
return b.notifyMatchFailure(op, "expect one reduction per output");
// Rewrite part.
@@ -318,6 +328,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// TODO: generalize when multi-reduction support is available.
SmallVector<Value> newOutputs;
newOutputs.reserve(op.getNumOutputs());
SmallVector<Operation *> initOrAllocTensorOps;
SmallVector<linalg::FillOp> fillOps;
fillOps.reserve(op.getNumOutputs());
for (auto it : llvm::zip(op.outputs(), neutralElements)) {
@@ -327,12 +338,19 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
reductionDimSize / splitFactor, insertSplitDimension);
SmallVector<Value> dims =
tensor::createDynamicDimValues(b, loc, rankedTensor);
Value initTensor = b.create<linalg::InitTensorOp>(
loc, dims, newT.getShape(), t.getElementType());
Value initOrAllocTensor;
if (useAlloc) {
initOrAllocTensor =
b.create<bufferization::AllocTensorOp>(loc, newT, dims);
} else {
initOrAllocTensor = b.create<linalg::InitTensorOp>(
loc, dims, newT.getShape(), t.getElementType());
}
Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
fillOps.push_back(
b.create<linalg::FillOp>(op->getLoc(), constantOp, initTensor));
b.create<linalg::FillOp>(op->getLoc(), constantOp, initOrAllocTensor));
newOutputs.push_back(fillOps.back().getResult(0));
initOrAllocTensorOps.push_back(initOrAllocTensor.getDefiningOp());
}
// Step 2. Reindex / expand indexing maps.
@@ -423,7 +441,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
// TODO: extend when multi-reduction support is available.
assert(fillOps.size() == results.size() && results.size() == 1);
b.replaceOp(op, results.front()->getResults());
return SplitReductionResult{fillOps.front(),
return SplitReductionResult{initOrAllocTensorOps.front(), fillOps.front(),
cast<LinalgOp>(genericOp.getOperation()),
results.front()};
}
@@ -434,18 +452,21 @@ struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
/// Construct a generic pattern applied to all LinalgOp that verify `filter`.
LinalgSplitReduction(MLIRContext *context,
ControlSplitReductionFn controlSplitReductionFn,
LinalgTransformationFilter f, PatternBenefit benefit = 1)
LinalgTransformationFilter f, bool useAlloc = false,
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlSplitReductionFn(std::move(controlSplitReductionFn)),
filter(std::move(f)) {}
useAlloc(useAlloc), filter(std::move(f)) {}
LogicalResult matchAndRewrite(LinalgOp op,
PatternRewriter &rewriter) const override {
return splitReduction(rewriter, op, controlSplitReductionFn, filter);
return splitReduction(rewriter, op, controlSplitReductionFn, filter,
useAlloc);
}
private:
ControlSplitReductionFn controlSplitReductionFn;
bool useAlloc;
LinalgTransformationFilter filter;
};
@@ -454,7 +475,7 @@ private:
void linalg::populateSplitReductionPattern(
RewritePatternSet &patterns,
const ControlSplitReductionFn &controlSplitReductionFn,
const LinalgTransformationFilter &f) {
const LinalgTransformationFilter &f, bool useAlloc) {
patterns.add<LinalgSplitReduction>(patterns.getContext(),
controlSplitReductionFn, f);
controlSplitReductionFn, f, useAlloc);
}

@@ -3,6 +3,7 @@
// CHECK-LABEL: func.func @matmul_split
func.func @matmul_split(%A : tensor<?x256xf32>, %B: tensor<256x32xf32>, %C: tensor<?x32xf32>) -> tensor<?x32xf32> {
// CHECK: bufferization.alloc_tensor({{.*}}) : tensor<?x32x64xf32>
// CHECK: linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor<?x256xf32>, tensor<256x32xf32>, tensor<64x4xi1>)
@@ -30,6 +31,7 @@ transform.with_pdl_patterns {
transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @pdl_target in %arg1
%1:3 = transform.structured.split_reduction_by_scaling %0 { split_factor = 4, insert_split_dimension = 2}
%1:3 = transform.structured.split_reduction %0
{ split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc}
}
}

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -41,6 +42,7 @@ struct TestLinalgTransforms
void getDependentDialects(DialectRegistry &registry) const override {
// clang-format off
registry.insert<AffineDialect,
bufferization::BufferizationDialect,
memref::MemRefDialect,
scf::SCFDialect,
linalg::LinalgDialect,