[mlir][linalg] Constant fold linalg.generic that are transposes

This commit adds a pattern to perform constant folding on linalg.generic
ops that are essentially transposes. We see real cases where model
importers generate such patterns.
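
For example, a transpose written as a linalg.generic over a constant, as in
the minimal sketch below (mirroring the tests added in this commit), now
folds into a plain transposed constant:

%input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
%1 = linalg.generic {
  indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
                   affine_map<(d0, d1) -> (d0, d1)>],
  iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
  linalg.yield %arg1 : f32
} -> tensor<3x2xf32>
// After folding, %1 becomes:
// constant dense<[[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]> : tensor<3x2xf32>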

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D110597
Author: Lei Zhang
Date: 2021-10-08 08:06:31 -04:00
parent 80c27abb2f
commit 4cd7ff6728
2 changed files with 375 additions and 9 deletions

@@ -1164,10 +1164,11 @@ private:
/// Pattern to fold a generic op with a splat or scalar constant. Does not
/// handle cases where the constant is not single-valued.
-class FoldConstants : public OpRewritePattern<GenericOp> {
+class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
-  FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
-                PatternBenefit benefit = 1)
+  FoldScalarOrSplatConstant(MLIRContext *context,
+                            ControlElementwiseOpsFusionFn &fun,
+                            PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
@@ -1268,6 +1269,237 @@ public:
private:
ControlElementwiseOpsFusionFn controlFn;
};
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
/// bool matchIndexingMaps(GenericOp genericOp) const;
/// RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the
/// output.
template <typename ConcreteType>
class FoldConstantBase : public OpRewritePattern<GenericOp> {
public:
struct APIntOrFloatArray {
SmallVector<APInt> apInts;
SmallVector<APFloat> apFloats;
};
using RegionComputationFn =
std::function<APIntOrFloatArray(APIntOrFloatArray)>;
FoldConstantBase(MLIRContext *context,
const ControlElementwiseOpsFusionFn &controlFn,
PatternBenefit benefit = 1)
: OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (genericOp.hasBufferSemantics())
return failure();
// Only support ops generating one output for now.
if (genericOp.getNumOutputs() != 1)
return failure();
auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
// Require the output types to be static given that we are generating
// constants.
if (!outputType || !outputType.hasStaticShape())
return failure();
if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
return operand->get().getType().isa<ShapedType>();
}))
return failure();
// Make sure all element types are the same.
auto getOperandElementType = [](OpOperand *operand) {
return operand->get().getType().cast<ShapedType>().getElementType();
};
if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
getOperandElementType)))
return failure();
// We can only handle the case where we have int/float elements.
auto elementType = outputType.getElementType();
if (!elementType.isIntOrFloat())
return failure();
// Require all indexing maps to be permutations for now. This is common and
// it simplifies input/output access greatly: we can do the data shuffling
// entirely in the compiler, without needing to turn all indices into
// Values, apply affine maps on them, and then match the constants back
// again.
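// For example, affine_map<(d0, d1) -> (d1, d0)> is a permutation, while
// broadcast-like maps such as affine_map<(d0, d1) -> (d0)> are not and are
// rejected here.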
if (!llvm::all_of(genericOp.getIndexingMaps(),
[](AffineMap map) { return map.isPermutation(); }))
return failure();
for (OpOperand *operand : genericOp.getOutputOperands()) {
if (genericOp.payloadUsesValueFromOperand(operand))
return failure();
}
// Further check the indexing maps are okay for the ConcreteType.
if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
return failure();
// Defer to the concrete type to check the region and discover the
// computation inside.
RegionComputationFn computeFn =
static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
if (!computeFn)
return failure();
// All inputs should be constants.
int numInputs = genericOp.getNumInputs();
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
if (!matchPattern(operand.value()->get(),
m_Constant(&inputValues[operand.index()])))
return failure();
}
// Identified this as a potential candidate for folding. Now check the
// policy to see whether we are allowed to proceed.
for (int i = 0; i < numInputs; ++i) {
OpOperand *consumer = genericOp.getInputOperand(i);
OpResult producer = consumer->get().cast<OpResult>();
if (!controlFn(producer, *consumer))
return failure();
}
auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
int64_t numElements = outputType.getNumElements();
// Use APInt/APFloat instead of Attribute here for constructing the output.
// This helps to avoid blowing up compiler memory usage: Attributes would be
// uniqued and live as long as the MLIRContext.
SmallVector<APInt> intOutputValues;
SmallVector<APFloat> fpOutputValues;
if (elementType.template isa<FloatType>())
fpOutputValues.resize(numElements, APFloat(0.f));
else
intOutputValues.resize(numElements);
// Return the constant dim positions from the given permutation map.
auto getDimPositions = [](AffineMap map) {
SmallVector<unsigned> dims;
dims.reserve(map.getNumResults());
for (AffineExpr result : map.getResults()) {
dims.push_back(result.cast<AffineDimExpr>().getPosition());
}
return dims;
};
SmallVector<SmallVector<unsigned>> inputDims;
for (int i = 0; i < numInputs; ++i)
inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
auto outputShape = outputType.getShape();
// Transpose the input constant. Because we don't know its rank in advance,
// we need to loop over the range [0, element count) and delinearize the
// index.
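// E.g., with loopBounds = [2, 3], linearIndex0 = 4 delinearizes to
// indices = [1, 1]: 4 % 3 = 1 for the last dimension, then 4 / 3 = 1 and
// 1 % 2 = 1 for the first.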
for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
SmallVector<uint64_t> indices(loopBounds.size(), 0);
int totalCount = linearIndex0;
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
indices[dim] = totalCount % loopBounds[dim];
totalCount /= loopBounds[dim];
}
SmallVector<SmallVector<uint64_t>> srcIndices;
for (int i = 0; i < numInputs; ++i)
srcIndices.emplace_back(loopBounds.size(), 0);
SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
for (int i = 0; i < numInputs; ++i)
srcIndices[i][dim] = indices[inputDims[i][dim]];
dstIndices[dim] = indices[outputDims[dim]];
}
uint64_t linearIndex1 = dstIndices.front();
for (int dim = 1; dim < outputType.getRank(); ++dim)
linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
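// E.g., for outputShape = [3, 2] and dstIndices = [2, 1], this row-major
// linearization gives 2 * 2 + 1 = 5.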
// Collect constant elements for all inputs at this loop iteration.
SmallVector<APInt> intValues;
SmallVector<APFloat> fpValues;
if (elementType.isa<FloatType>()) {
for (int i = 0; i < numInputs; ++i)
fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
} else {
for (int i = 0; i < numInputs; ++i)
intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
}
// Invoke the computation to get the corresponding constant output
// element.
APIntOrFloatArray inputs = {intValues, fpValues};
APIntOrFloatArray outputs = computeFn(inputs);
if (elementType.isa<FloatType>()) {
fpOutputValues[linearIndex1] = outputs.apFloats.front();
} else {
intOutputValues[linearIndex1] = outputs.apInts.front();
}
}
DenseIntOrFPElementsAttr outputAttr;
if (elementType.isa<FloatType>()) {
outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
} else {
outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
}
rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
return success();
}
private:
ControlElementwiseOpsFusionFn controlFn;
};
// Folds linalg.generic ops that are actually transposes on constant values.
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
using FoldConstantBase::FoldConstantBase;
bool matchIndexingMaps(GenericOp genericOp) const {
// We should have one input and one output.
return genericOp.getIndexingMaps().size() == 2;
}
RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
// Make sure the region only contains a yield op.
Block &body = genericOp.region().front();
if (!llvm::hasSingleElement(body))
return nullptr;
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
if (!yieldOp)
return nullptr;
// The yield op should return the block argument that corresponds to the
// input.
for (Value yieldVal : yieldOp.values()) {
auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
if (!yieldArg || yieldArg.getOwner() != &body)
return nullptr;
if (yieldArg.getArgNumber() != 0)
return nullptr;
}
// No computation; just return the original value.
return [](APIntOrFloatArray inputs) { return inputs; };
}
ControlElementwiseOpsFusionFn controlFn;
};
} // namespace
static Optional<SmallVector<Value>>
@@ -1442,8 +1674,9 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
void mlir::linalg::populateElementwiseOpsFusionPatterns(
RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
auto *context = patterns.getContext();
-  patterns.add<FuseElementwiseOps, FoldConstants>(
-      context, options.controlElementwiseOpsFusionFn);
+  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
+               FoldConstantTranspose>(context,
+                                      options.controlElementwiseOpsFusionFn);
patterns.add<RemoveOutsDependency>(context);
populateFoldReshapeOpsByExpansionPatterns(patterns,
options.controlFoldingReshapesFn);
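
The tests below exercise the new pattern through the elementwise fusion pass.
A sketch of the FileCheck driver line, assuming this test file's pass flag of
the era (-linalg-fusion-for-tensor-ops; the flag name is an assumption and is
not shown in this diff):

// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file | FileCheck %s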

@@ -755,15 +755,15 @@ func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<
%2:2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>,
affine_map<(d0, d1) -> ()>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
%3 = addf %arg1, %arg2 : f32
linalg.yield %3, %arg3 : f32, i32
} -> (tensor<?x?xf32>, tensor<?x?xi32>)
return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
}
@@ -774,3 +774,136 @@ func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<
// CHECK-SAME: ins(%{{.+}} : tensor<?x?xf32>)
// CHECK: %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
// CHECK: linalg.yield %[[YIELD]], %[[C42]] : f32, i32
// -----
// CHECK-LABEL: @transpose_fold_2d_fp32
func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
%input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
// CHECK: %[[CST:.+]] = constant
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<3x2xf32>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_fold_2d_fp64
func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
%input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
// CHECK: %[[CST:.+]] = constant
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
^bb0(%arg1: f64, %arg2: f64):
linalg.yield %arg1 : f64
} -> tensor<3x2xf64>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xf64>
}
// -----
// CHECK-LABEL: @transpose_fold_4d_i32
func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
%input = constant dense<[[
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi32>
// CHECK: %[[CST:.+]] = constant dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
^bb0(%arg1: i32, %arg2: i32):
linalg.yield %arg1 : i32
} -> tensor<3x1x4x2xi32>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi32>
}
// -----
// CHECK-LABEL: @transpose_fold_4d_i16
func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
%input = constant dense<[[
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
]]> : tensor<1x2x3x4xi16>
// CHECK: %[[CST:.+]] = constant dense<[
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
// CHECK-SAME{LITERAL}: ]>
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
} ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
^bb0(%arg1: i16, %arg2: i16):
linalg.yield %arg1 : i16
} -> tensor<3x1x4x2xi16>
// CHECK: return %[[CST]]
return %1 : tensor<3x1x4x2xi16>
}
// -----
// CHECK-LABEL: @transpose_nofold_non_cst_input
func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
// CHECK: linalg.generic
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %arg1 : f32
} -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_yield_const
func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
%input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
%cst = constant 8.0 : f32
// CHECK: linalg.generic
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
linalg.yield %cst : f32
} -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}
// -----
// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
%input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
// CHECK: linalg.generic
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
^bb0(%arg1: f32, %arg2: f32):
%add = addf %arg1, %arg1 : f32
linalg.yield %add : f32
} -> tensor<3x2xf32>
return %1 : tensor<3x2xf32>
}