//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::linalg;

/// Conditions for elementwise fusion of generic operations.
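///
/// For example (illustrative; operands and attributes are elided, not taken
/// from a specific test), a fully parallel producer whose result indexing map
/// is a permutation can be fused into an input of an elementwise consumer:
///
///   %0 = linalg.generic { all-"parallel" iterators, permutation result map }
///          ins(%a : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) ...
///   %1 = linalg.generic { ... } ins(%0, %b : ...) outs(%init2 : ...) ...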
static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer,
                                     OpOperand *consumerOpOperand) {
  // Producer and consumer must have tensor semantics.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isInputTensor(consumerOpOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Currently support only operations with a single result.
  if (producer.getNumOutputs() != 1)
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  return producerResultIndexMap.isPermutation();
}

/// Compute the indexing map in the fused operation to use for an operand of
/// the `producer`, given the indexing map of the result of the producer in the
/// consumer.
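///
/// A worked example (maps are illustrative, not from a specific test):
///   producerResultIndexMap : (d0, d1) -> (d1, d0)
///   its inverse            : (t0, t1) -> (t1, t0)
///   argMap                 : (d0, d1) -> (d0)
///   argMap o inverse       : (t0, t1) -> (t1)
///   composed with a consumer map (c0, c1) -> (c1, c0), this yields
///                            (c0, c1) -> (c0)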
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is same as the consumer loop. For each producer arg
  // the indexing map to be computed is a map from consumer loop -> producer
  // arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
  return t1.compose(fusedConsumerArgIndexMap);
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp,
                                 AffineMap consumerToProducerLoopsMap,
                                 OpOperand *consumerOpOperand,
                                 unsigned nloops) {
  auto producer = cast<GenericOp>(consumerOpOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  Block *fusedBlock = new Block();
  fusedOp.region().push_back(fusedBlock);
  BlockAndValueMapping mapper;
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(fusedBlock);

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<mlir::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.dim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           consumerOpOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));

  // 4.b. Producer output operand/map that is fused needs to be mapped to the
  // producer bbArg if it is an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    BlockArgument bbArg = producerBlock.getArguments()
                              .drop_front(producer.getNumInputs())
                              // TODO: bbArg index of
                              .front();
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  }
  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumInputs())
           .drop_front(consumerOpOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 6. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumOutputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType()));
  // 7. All of producer's output operands except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto yieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  unsigned producerResultNumber = 0;
  Value replacement =
      mapper.lookupOrDefault(yieldOp.getOperand(producerResultNumber));
  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == yieldOp.getOperand(producerResultNumber)) {
    if (auto bb = replacement.dyn_cast<BlockArgument>())
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.getOperations())
    rewriter.clone(op, mapper);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

static Optional<SmallVector<Value>>
fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand,
                       const ControlElementwiseOpsFusionFn &controlFn,
                       PatternRewriter &rewriter) {
  auto consumer = cast<GenericOp>(consumerOpOperand->getOwner());
  if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) ||
      !controlFn(producer->getResult(0), *consumerOpOperand))
    return llvm::None;

  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isInputTensor(consumerOpOperand) &&
         "expected producer of input operand");

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedOperands;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedOperands.reserve(producer->getNumOperands() +
                        consumer->getNumOperands());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of
  // `generateFusedElementwiseOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
  SmallVector<OpOperand *>::iterator it =
      llvm::find(consumerInputs, consumerOpOperand);
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  assert(producer->getNumResults() == 1 && "expected single result producer");
  AffineMap producerResultIndexMap =
      producer.getTiedIndexingMap(producer.getOutputOperand(0));
  for (OpOperand *opOperand : producer.getInputOperands()) {
    fusedOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 4.b. Producer output operand/map that is fused needs to be passed if it is
  // an "initTensor" (i.e. its value is actually read).
  assert(producer->getNumResults() == 1 && "expected single result producer");
  if (producer.isInitTensor(producer.getOutputOperand(0))) {
    fusedOperands.push_back(producer.getOutputOperand(0)->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        producer.getOutputOperand(0), producerResultIndexMap,
        consumer.getTiedIndexingMap(consumerOpOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  }
  // 6. All of consumer's output operands (skip operands: added by the builder).
  for (OpOperand *opOperand : consumer.getOutputOperands())
    fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand));
  // 7. All of producer's output operands/maps except the one fused.
  // TODO: allow fusion of multi-result producers.
  assert(producer->getNumResults() == 1 && "expected single result producer");

  // Generate the fused op.
  SmallVector<Value> consumerOutputs = consumer.getOutputOperands();
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), consumer->getResultTypes(),
      /*inputs=*/fusedOperands,
      // TODO: handle outputs.
      consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getTiedIndexingMap(consumerOpOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(rewriter, fusedOp,
                                   consumerToProducerLoopsMap,
                                   consumerOpOperand, consumer.getNumLoops());
  return SmallVector<Value>(fusedOp->getResults());
}

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
/// %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3]] :
///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
/// %0 = op ... : tensor<?x?x4x5xf32>
/// with output index_map
///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
template <typename TensorReshapeOp>
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
  ArrayRef<int64_t> sourceShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  SmallVector<AffineExpr> resultExprs;
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!indices.empty());
    SmallVector<int64_t> sizes;
    SmallVector<AffineExpr> dimExprs;
    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
                             sourceExprs.slice(indices[0], indices.size()))) {
      if (std::get<0>(en) == 1)
        continue;
      sizes.push_back(std::get<0>(en));
      dimExprs.push_back(std::get<1>(en));
    }
    AffineExpr linearizedExpr =
        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
    resultExprs.push_back(linearizedExpr);
  }
  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
                        resultExprs, context);
}

// TensorExpandShapeOp is fusable with its consumer (i.e. reshape as a
// producer). Fusing when the reshape's source has higher rank than its result
// would require mods and divs in the indexing maps of the fused op, which
// would make it non-invertible.
static bool isTensorReshapeOpFoldableByLinearization(
    TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
  if (!asProducer)
    return false;
  return useIndexMap.isPermutation();
}

// TensorCollapseShapeOp is fusable with its producer (i.e. reshape as a
// consumer).
static bool isTensorReshapeOpFoldableByLinearization(
    TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
  if (asProducer)
    return false;
  return useIndexMap.isPermutation();
}

/// Check if the reshape operation is only an expansion into, or a collapse of,
/// unit-dimensions.
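///
/// For example (illustrative), a reassociation [[0, 1], [2]] expanding
/// tensor<?x4xf32> into tensor<1x?x4xf32> qualifies: the group [0, 1]
/// introduces a single unit dimension and the group [2] introduces none.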
template <typename TensorReshapeOp>
static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
  constexpr bool isExpanding =
      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
  ArrayRef<int64_t> expandedShape =
      (isExpanding ? reshapeOp.getResultType().getShape()
                   : reshapeOp.getSrcType().getShape());
  for (auto &indices : reshapeOp.getReassociationIndices()) {
    unsigned numUnitDims = 0;
    for (int64_t position : indices)
      if (expandedShape[position] == 1)
        numUnitDims++;
    if (numUnitDims != indices.size() - 1)
      return false;
  }
  return true;
}

/// Conditions for folding a generic operation with a reshape op by expanding
/// the iteration space dimensionality for tensor operations. These are
/// preconditions assumed by `foldReshapeByDimExpansion` which implements the
/// following fusion pattern.
///
/// Consider
///
///  %c = linalg.generic ins(%a, %b : tensor<?x?x?xf32>, tensor<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = linalg.tensor_expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `genericOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor_expand_shape.
/// The indexing_map of the fused tensor in the `genericOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing maps of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the generic op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///         indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the operands to the linalg generic are now 6-D and 4-D respectively,
/// reshapes can be introduced to make it consistent
///
///  %0 = linalg.tensor_expand_shape %a [[0, 1, 2], [3, 4], [5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = linalg.tensor_expand_shape %b [[0, 1, 2], [3]]
///       : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops are parallel loops.
  return genericOp.hasTensorSemantics() &&
         llvm::all_of(genericOp.indexing_maps().getValue(),
                      [](Attribute attr) {
                        return attr.cast<AffineMapAttr>()
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 &&
         llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
           return attr.cast<StringAttr>().getValue() ==
                  getParallelIteratorTypeName();
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
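///
/// For example (illustrative): expanding loop d1 of a 2-d op with
/// reassociation [[0], [1, 2]] and expanded shape [?, 4, 5] yields
///   reassociation     = [[0], [1, 2]]
///   expandedShapeMap  = [[?], [4, 5]]
///   expandedOpNumDims = 3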
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand);

  Optional<SmallVector<int64_t, 4>> originalLoopRange =
      linalgOp.getStaticLoopRanges();
  if (!originalLoopRange)
    return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {(*originalLoopRange)[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (auto numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, access of indices in the original
/// operation needs to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant). For now check that these
/// are all statically sized. Note that this could be extended to handle the
/// dynamic case, but the implementation below uses `affine.apply` which seems
/// to have issues when the shapes are not static.
static LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                           const ExpansionInfo &expansionInfo,
                                           PatternRewriter &rewriter) {
  if (!genericOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            genericOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
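///
/// For example (illustrative): with original map (d0, d1) -> (d1, d0) and
/// expansion d0 -> {e0, e1}, d1 -> {e2}, the expanded map is
///   (e0, e1, e2) -> (e2, e0, e1)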
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
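///
/// For example (illustrative): a tensor<?x4xf32> operand accessed with map
/// (d0, d1) -> (d0, d1), where d0 expands to shape [2, 3] and d1 stays [4],
/// becomes tensor<2x3x4xf32>.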
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `linalg.tensor_expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `linalg.tensor_collapse_shape` used to collapse the result of the expanded
/// op to get the value that can replace all uses of the results of the original
/// op.
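///
/// For example (illustrative): for an operand map (d0, d1) -> (d1, d0), where
/// d1 expands to two dimensions and d0 to one, the operand's reassociation is
/// [[0, 1], [2]].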
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics. The
/// indices of the original operation need to be recovered by linearizing the
/// indices of the corresponding dimensions of the expanded operation. For now
/// it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
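///
/// For example (illustrative): if dimension d of the original op expands to
/// (e0, e1, e2) with shapes (s0, 3, 4), the original index is recovered as
/// (e0 * 3 + e1) * 4 + e2, built up one `affine.apply` at a time below.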
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.dim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.dim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.dim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Implements the fusion of a tensor_collapse_shape or a tensor_expand_shape
/// op and a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static Optional<SmallVector<Value>>
fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<TensorExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<TensorCollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          genericOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), rewriter)))
    return llvm::None;

  if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
    return llvm::None;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(genericOp.getNumInputs());
  for (OpOperand *opOperand : genericOp.getInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
                                               : collapsingReshapeOp.src());
      continue;
    }
    if (genericOp.isInputTensor(opOperand)) {
      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                          indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
            genericOp.getLoc(), expandedOperandType, opOperand->get(),
            reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  Location loc = genericOp.getLoc();
  SmallVector<Value> outputs;
  for (OpOperand *opOperand : genericOp.getOutputOperands()) {
    AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
    RankedTensorType expandedOutputType =
        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                        indexingMap, expansionInfo);
    if (expandedOutputType != opOperand->get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      outputs.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp.getLoc(), expandedOutputType, opOperand->get(),
          reassociation));
    } else {
      // The output type is unchanged; carry the operand over as-is so every
      // original output has a counterpart in the expanded op.
      outputs.push_back(opOperand->get());
    }
  }

  // The iterator types of the expanded op are all parallel.
  SmallVector<StringRef> iteratorTypes(expansionInfo.getExpandedOpNumDims(),
                                       getParallelIteratorTypeName());

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(genericOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = genericOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : genericOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              genericOp.getTiedIndexingMap(
                  genericOp.getOutputOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<TensorCollapseShapeOp>(
          genericOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fold a tensor_expand_shape op with its consumer by using the
/// source of the reshape op as the operand in the consumer (instead of the
/// result of the reshape). The corresponding index map in the consumer needs
/// to be modified to linearize the folded dimension.
///
/// For example,
///
///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///  %0 = linalg.tensor_expand_shape %arg0 [[0], [1, 2], [3]]
///       : tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
///  %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
///         ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
///         -> tensor<?x?x4x?xf32>
///
/// can be folded into
///
///  #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
///  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///  %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
///         ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
///         -> tensor<?x?x4x?xf32>
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldProducerReshapeOpByLinearization
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
      if (!reshapeOp)
        continue;

      if (!isTensorReshapeOpFoldableByLinearization(
              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer=*/true) ||
          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
        continue;

      // Compute the fused operands list.
      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
      fusedOperands[en.index()] = reshapeOp.src();
      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      llvm::append_range(fusedOperands, outputOperands);

      // Compute indexing_maps for the fused operation. The indexing_maps for
      // the operands of the consumer that aren't fused stay the same.
      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();

      // Accepted consumer maps are either identity or permutation.
      auto invMap = inversePermutation(fusedIndexMaps[en.index()]);

      // Compute the indexing map to use for the result of the producer.
      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
      // The modified map cannot have symbols.
      if (modifiedMap.getNumSymbols())
        return failure();
      for (AffineExpr expr : modifiedMap.getResults()) {
        if (!expr.isPureAffine())
          return failure();
      }
      fusedIndexMaps[en.index()] = modifiedMap;

      // Further check that the resulting index maps can be fused and
      // inverted. Without this the resulting op is not legal.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      rewriter.startRootUpdate(genericOp);
      genericOp->setOperands(fusedOperands);
      genericOp.indexing_mapsAttr(
          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
      rewriter.finalizeRootUpdate(genericOp);
      return success();
    }
    return failure();
  }
};
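
/// Convert reassociation maps (one `AffineMap` per reassociation group) into
/// the equivalent `ReassociationIndices` representation.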
static SmallVector<ReassociationIndices>
getReassociationIndices(ArrayRef<AffineMap> maps) {
  SmallVector<ReassociationIndices> reassociation;
  for (AffineMap map : maps) {
    ReassociationIndices indices;
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
      indices.push_back(pos);
    }
    reassociation.push_back(indices);
  }
  return reassociation;
}

/// Pattern to move a rank-reducing reshape after an elementwise linalg generic
/// op. This is useful to expose more fusion opportunities between named ops
/// and generic ops. This can only be done if there is no broadcast or
/// permutation within the dimensions we need to merge.
///
/// For example,
///
///  %0 = linalg.tensor_expand_shape %A [[0, 1], [2]]
///       : tensor<12544x16xf32> into tensor<112x112x16xf32>
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
///    affine_map<(d0, d1, d2) -> (d2)>,
///    affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
///    ["parallel", "parallel", "parallel"]} {
///  } -> tensor<112x112x16xf32>
///
/// into
///
///  %2 = linalg.generic {indexing_maps = [
///    affine_map<(d0, d1) -> (d0, d1)>,
///    affine_map<(d0, d1) -> (d1)>,
///    affine_map<(d0, d1) -> (d0, d1)>],
///    iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1
///    : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) {
///  } -> tensor<12544x16xf32>
///  %3 = linalg.tensor_expand_shape %2 [[0, 1], [2]]
///       : tensor<12544x16xf32> into tensor<112x112x16xf32>
struct PushExpandingReshape : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Only apply to elementwise linalg on tensors.
    if (!genericOp.hasTensorSemantics() ||
        genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();
    // Only support identity output maps. It could be extended to permutations
    // if needed.
    if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) {
          return !genericOp.getTiedIndexingMap(opOperand).isIdentity();
        }))
      return failure();
    int64_t destRank = genericOp.getNumParallelLoops();
    SmallVector<Value> newOperands = genericOp.getInputOperands();
    TensorExpandShapeOp reshapeFound;
    // 1. Look for tensor_expand_shape operands and save the dimensions that
    // get merged.
    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
    for (auto en : llvm::enumerate(inputOperands)) {
      auto reshapeOp =
          en.value()->get().template getDefiningOp<TensorExpandShapeOp>();
      if (!reshapeOp)
        continue;
      // TODO: We could support non-identity maps as long as the merged
      // dimensions are still contiguous.
      if (!genericOp.getTiedIndexingMap(en.value()).isIdentity())
        continue;
      if (reshapeFound) {
        // Only support a second reshape op if it has the same reassociation
        // maps.
        if (reshapeFound.getReassociationMaps() ==
            reshapeOp.getReassociationMaps())
          newOperands[en.index()] = reshapeOp.src();
        continue;
      }
      reshapeFound = reshapeOp;
      newOperands[en.index()] = reshapeOp.src();
    }
    if (!reshapeFound)
      return failure();

    // Calculate the reassociation indices and the reassociated reverse map.
    SmallVector<ReassociationIndices> reassociation =
        getReassociationIndices(reshapeFound.getReassociationMaps());
    SmallVector<unsigned> remap(destRank);
    for (auto &indices : llvm::enumerate(reassociation)) {
      for (int64_t index : indices.value()) {
        remap[index] = indices.index();
      }
    }
    // 2. Verify that we can merge the dimensions in the linalg op and that we
    // don't need to create new reshape operands. Inserting new reshape
    // operands would defeat the purpose of the transformation.
    for (auto en : llvm::enumerate(inputOperands)) {
      if (en.value()->get() == newOperands[en.index()]) {
        AffineMap map = genericOp.getTiedIndexingMap(en.value());
        for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
          if (reassociation[remap[map.getDimPosition(i)]].size() > 1)
            return failure();
        }
      }
    }

    // 3. Calculate the affine map remapping and the reassociation to apply to
    // output tensors.
    SmallVector<AffineMap> newMaps;
    unsigned newRank = reassociation.size();
    for (auto map : genericOp.getIndexingMaps()) {
      SmallVector<AffineExpr> newExprs;
      for (auto expr : map.getResults()) {
        unsigned position = expr.template cast<AffineDimExpr>().getPosition();
        // Skip dimensions merged except for the last of the group.
        if (reassociation[remap[position]].back() == position) {
          newExprs.push_back(
              getAffineDimExpr(remap[position], genericOp.getContext()));
        }
      }
      newMaps.push_back(
          AffineMap::get(newRank, 0, newExprs, genericOp.getContext()));
    }

    // 4. Reshape the output tensors.
    SmallVector<Value> newOutputs;
    SmallVector<Type> newOutputTypes;
    for (auto output : genericOp.outputs()) {
      auto newOutputType = RankedTensorType::get(
          reshapeFound.getSrcType().getShape(),
          output.getType().template cast<RankedTensorType>().getElementType());
      Value newOutput = rewriter.create<TensorCollapseShapeOp>(
          genericOp->getLoc(), newOutputType, output, reassociation);
      newOutputTypes.push_back(newOutputType);
      newOutputs.push_back(newOutput);
    }
    // 5. Create a new generic op with lower rank.
    SmallVector<StringRef> iteratorTypes(newRank,
                                         getParallelIteratorTypeName());
    auto newOp = rewriter.create<GenericOp>(genericOp->getLoc(), newOutputTypes,
                                            newOperands, newOutputs, newMaps,
                                            iteratorTypes);
    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
                                newOp.region().begin());
    // 6. Reshape the results so that their types match the original uses.
    SmallVector<Value> newResults;
    for (auto result : llvm::enumerate(newOp->getResults())) {
      newResults.push_back(rewriter.create<TensorExpandShapeOp>(
          genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()],
          result.value(), reassociation));
    }
    rewriter.replaceOp(genericOp, newResults);
    return success();
  }
};

/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
/// when the reshape op is collapsing dimensions. The dimensionality of the
/// loop in the consumer is expanded.
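///
/// For example (illustrative), with %0 defined by an elementwise producer:
///
///  %1 = linalg.tensor_collapse_shape %0 [[0, 1], [2]]
///       : tensor<?x4x?xf32> into tensor<?x?xf32>
///  %2 = linalg.generic { ... } ins(%1 : tensor<?x?xf32>) ...
///
/// is rewritten so that %2 iterates over the expanded (3-d) space and consumes
/// %0 directly, with a collapse reapplied to the result where needed.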
class FoldWithProducerReshapeOpByExpansion
    : public OpRewritePattern<GenericOp> {
public:
  FoldWithProducerReshapeOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
      TensorCollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
        continue;

      Optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(genericOp, replacementValues.getValue());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor_collapse_shape or tensor_expand_shape op with its
/// producer. The output indexing map of the producer needs to be modified to
/// linearize the folded dimension.
template <bool foldUnitDimReshapesOnly, typename TensorReshapeOp>
struct FoldConsumerReshapeOpByLinearization
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
    if (!producer || !producer.hasTensorSemantics() ||
        producer.getNumOutputs() != 1 ||
        !isTensorReshapeOpFoldableByLinearization(
            reshapeOp,
            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer=*/false) ||
        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
      return failure();
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();

    auto invMap = inversePermutation(
        producer.getTiedIndexingMap(producer.getOutputOperand(0)));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
            producer, "fused op indexing map is not affine");
      }
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resulting op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
      return rewriter.notifyMatchFailure(
          producer, "fused op loop bound computation failed");
    }

    Location loc = producer.getLoc();
    SmallVector<Value> inputOperands = producer.getInputOperands();
    Value output = rewriter.create<TensorReshapeOp>(
        loc, producer.getOutputOperand(0)->get(),
        reshapeOp.getReassociationExprs());
    auto fusedOp = rewriter.create<GenericOp>(
        loc, reshapeOp.getResultType(),
        /*inputs=*/inputOperands,
        // TODO: handle outputs.
        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr);
    auto &fusedRegion = fusedOp->getRegion(0);
    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
    return success();
  }
};

/// Pattern to fold a tensor_expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<TensorExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(
      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
      PatternBenefit benefit = 1)
      : OpRewritePattern<TensorExpandShapeOp>(context, benefit),
        controlFoldingReshapes(foldReshapes) {}

  LogicalResult matchAndRewrite(TensorExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
    if (!producer || producer.getNumOutputs() != 1 ||
        !isFusableWithReshapeByDimExpansion(producer,
                                            producer.getOutputOperand(0)) ||
        !controlFoldingReshapes(producer->getResult(0),
                                reshapeOp->getOpOperand(0)))
      return failure();
    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
    if (!replacementValues)
      return failure();
    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFoldingReshapes;
};

/// Pattern to fold a generic op with a splat constant or a scalar constant.
/// Does not handle cases where the constant is not single-valued.
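///
/// For example (illustrative), a splat constant input
///
///  %cst = arith.constant dense<1.0> : tensor<4xf32>
///  %0 = linalg.generic { ... } ins(%cst, %arg0 : tensor<4xf32>, tensor<4xf32>)
///
/// is replaced by a generic op that takes only %arg0 and uses the scalar
/// 1.0 : f32 directly inside its region.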
class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
public:
  FoldScalarOrSplatConstant(MLIRContext *context,
                            ControlElementwiseOpsFusionFn &fun,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasTensorSemantics())
      return failure();
    for (OpOperand *opOperand : genericOp.getInputOperands()) {
      Operation *def = opOperand->get().getDefiningOp();
      Attribute constantAttr;
      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
        {
          DenseElementsAttr splatAttr;
          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
              splatAttr.isSplat() &&
              splatAttr.getType().getElementType().isIntOrFloat()) {
            constantAttr = splatAttr.getSplatValue();
            return true;
          }
        }
        {
          IntegerAttr intAttr;
          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
            constantAttr = intAttr;
            return true;
          }
        }
        {
          FloatAttr floatAttr;
          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
            constantAttr = floatAttr;
            return true;
          }
        }
        return false;
      };

      auto resultValue = opOperand->get().dyn_cast<OpResult>();
      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
          !controlFn(resultValue, *opOperand))
        continue;

      // The operands and the indexing_maps of the fused operation are the
      // same as the operands and indexing_maps of the generic operation, with
      // the values at the constant index dropped.
      SmallVector<AffineMap> fusedIndexMaps;
      SmallVector<Value> fusedOperands;
      SmallVector<Location> fusedLocs{genericOp.getLoc()};
      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
      fusedOperands.reserve(genericOp.getNumInputs());
      fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
      for (OpOperand *inputOperand : genericOp.getInputOperands()) {
        if (inputOperand == opOperand)
          continue;
        Value inputValue = inputOperand->get();
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand));
        fusedOperands.push_back(inputValue);
        fusedLocs.push_back(inputValue.getLoc());
      }
      for (OpOperand *outputOperand : genericOp.getOutputOperands())
        fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand));

      // Check if the operation shapes to loops map is computable.
      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
        return rewriter.notifyMatchFailure(
            genericOp, "fused op loop bound computation failed");
      }

      // Create a constant scalar value from the splat constant.
      Value scalarConstant = rewriter.create<arith::ConstantOp>(
          def->getLoc(), constantAttr, constantAttr.getType());

      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
      auto fusedOp = rewriter.create<GenericOp>(
          rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
          /*inputs=*/fusedOperands,
          /*outputs=*/outputOperands,
          rewriter.getAffineMapArrayAttr(fusedIndexMaps),
          genericOp.iterator_types(),
          /*doc=*/nullptr,
          /*library_call=*/nullptr);

      // Map the block argument corresponding to the replaced argument with the
      // scalar constant.
      Region &region = genericOp->getRegion(0);
      Block &entryBlock = *region.begin();
      BlockAndValueMapping mapping;
      mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
                  scalarConstant);
      Region &fusedRegion = fusedOp->getRegion(0);
      rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
                                 mapping);
      rewriter.replaceOp(genericOp, fusedOp->getResults());
      return success();
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
/// and permutation indexing maps.
///
/// `ConcreteType` should provide methods with signatures
///
/// ```c++
///   bool matchIndexingMaps(GenericOp genericOp) const;
///   RegionComputationFn getRegionComputeFn(GenericOp) const;
/// ```
///
/// The latter inspects the region and returns the computation inside as a
/// functor. The functor will be invoked with constant elements for all inputs
/// and should return the corresponding computed constant element for the
/// output.
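///
/// A hypothetical subclass might look like the following sketch (the name and
/// the map check are illustrative only, not a prescribed implementation):
///
/// ```c++
///   struct FoldConstantElementwise
///       : public FoldConstantBase<FoldConstantElementwise> {
///     using FoldConstantBase::FoldConstantBase;
///     bool matchIndexingMaps(GenericOp genericOp) const {
///       // E.g. accept a single-input op (one input map + one output map).
///       return genericOp.getIndexingMaps().size() == 2;
///     }
///     RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
///       // Return a functor forwarding the single input element; a real
///       // implementation would inspect the region and bail out (return
///       // nullptr) if it is not of the expected form.
///       return [](const APIntOrFloatArray &in) {
///         APIntOrFloat out;
///         if (!in.apFloats.empty())
///           out.apFloat = in.apFloats[0];
///         else
///           out.apInt = in.apInts[0];
///         return out;
///       };
///     }
///   };
/// ```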
|
|
template <typename ConcreteType>
|
|
class FoldConstantBase : public OpRewritePattern<GenericOp> {
|
|
public:
|
|
struct APIntOrFloat {
|
|
Optional<APInt> apInt;
|
|
Optional<APFloat> apFloat;
|
|
};
|
|
struct APIntOrFloatArray {
|
|
SmallVector<APInt> apInts;
|
|
SmallVector<APFloat> apFloats;
|
|
};
|
|
using RegionComputationFn =
|
|
std::function<APIntOrFloat(const APIntOrFloatArray &)>;

  FoldConstantBase(MLIRContext *context,
                   const ControlElementwiseOpsFusionFn &controlFn,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (genericOp.hasBufferSemantics())
      return failure();

    // Only support ops generating one output for now.
    if (genericOp.getNumOutputs() != 1)
      return failure();

    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
    // Require the output type to be static, given that we are generating
    // constants.
    if (!outputType || !outputType.hasStaticShape())
      return failure();

    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().isa<ShapedType>();
        }))
      return failure();

    // Make sure all element types are the same.
    auto getOperandElementType = [](OpOperand *operand) {
      return operand->get().getType().cast<ShapedType>().getElementType();
    };
    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
                                        getOperandElementType)))
      return failure();

    // We can only handle the case where we have int/float elements.
    auto elementType = outputType.getElementType();
    if (!elementType.isIntOrFloat())
      return failure();

    // Require all indexing maps to be permutations for now. This is common
    // and it simplifies input/output access greatly: we can do the data
    // shuffling entirely in the compiler, without needing to turn all indices
    // into Values, apply affine maps to them, and then match the result back
    // to a constant.
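    // For instance, (d0, d1) -> (d1, d0) is a permutation, while
    // (d0, d1) -> (d0) and (d0, d1) -> (d0 + d1) are not.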
    if (!llvm::all_of(genericOp.getIndexingMaps(),
                      [](AffineMap map) { return map.isPermutation(); }))
      return failure();

    for (OpOperand *operand : genericOp.getOutputOperands()) {
      if (genericOp.payloadUsesValueFromOperand(operand))
        return failure();
    }

    // Further check that the indexing maps are okay for the ConcreteType.
    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
      return failure();

    // Defer to the concrete type to check the region and discover the
    // computation inside.
    RegionComputationFn computeFn =
        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
    if (!computeFn)
      return failure();

    // All inputs should be constants.
    int numInputs = genericOp.getNumInputs();
    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
      if (!matchPattern(operand.value()->get(),
                        m_Constant(&inputValues[operand.index()])))
        return failure();
    }

    // Identified this as a potential candidate for folding. Now check the
    // policy to see whether we are allowed to proceed.
    for (int i = 0; i < numInputs; ++i) {
      OpOperand *consumer = genericOp.getInputOperand(i);
      OpResult producer = consumer->get().cast<OpResult>();
      if (!controlFn(producer, *consumer))
        return failure();
    }

    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
    int64_t numElements = outputType.getNumElements();

    // Use APInt/APFloat instead of Attribute here for constructing the
    // output. This helps to avoid blowing up compiler memory usage:
    // Attributes would unify the following cases, but they live as long as
    // the MLIRContext.
    SmallVector<APInt> intOutputValues;
    SmallVector<APFloat> fpOutputValues;
    if (elementType.template isa<FloatType>())
      fpOutputValues.resize(numElements, APFloat(0.f));
    else
      intOutputValues.resize(numElements);

    // Return the constant dim positions from the given permutation map.
    auto getDimPositions = [](AffineMap map) {
      SmallVector<unsigned> dims;
      dims.reserve(map.getNumResults());
      for (AffineExpr result : map.getResults()) {
        dims.push_back(result.cast<AffineDimExpr>().getPosition());
      }
      return dims;
    };
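    // E.g., the permutation map (d0, d1, d2) -> (d2, d0, d1) yields the
    // positions [2, 0, 1].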

    SmallVector<SmallVector<unsigned>> inputDims;
    for (int i = 0; i < numInputs; ++i)
      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
    auto outputShape = outputType.getShape();

    // Allocate small vectors for index delinearization. Initial values do not
    // matter here as they will be overwritten later.
    SmallVector<uint64_t> indices(loopBounds.size(), 0);
    SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
    SmallVector<SmallVector<uint64_t>> srcIndices(
        numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
    SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
    uint64_t dstLinearIndex = 0;

    // Allocate space for the compute function inputs. Initial values do not
    // matter here as they will be overwritten later.
    APIntOrFloatArray computeFnInputs;

    auto inputShapes = llvm::to_vector<4>(
        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
          return operand->get().getType().cast<ShapedType>().getShape();
        }));

    // Given a `linearIndex`, remap it to a linear index to access linalg op
    // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`,
    // `srcLinearIndices`, and `dstLinearIndex` in place.
    auto computeRemappedLinearIndex = [&](int linearIndex) {
      int totalCount = linearIndex;
      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        indices[dim] = totalCount % loopBounds[dim];
        totalCount /= loopBounds[dim];
      }

      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
        for (int i = 0; i < numInputs; ++i)
          srcIndices[i][dim] = indices[inputDims[i][dim]];
        dstIndices[dim] = indices[outputDims[dim]];
      }

      dstLinearIndex = dstIndices.front();
      for (int i = 0; i < numInputs; ++i)
        srcLinearIndices[i] = srcIndices[i].front();

      for (int dim = 1; dim < outputType.getRank(); ++dim) {
        dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
        for (int i = 0; i < numInputs; ++i)
          srcLinearIndices[i] =
              srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
      }
    };
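    // Worked example (a sketch for a 2-D transpose): with loopBounds = [2, 3],
    // one input of type tensor<3x2xf32> indexed by (d0, d1) -> (d1, d0), and
    // an identity output map onto tensor<2x3xf32>, linearIndex = 1
    // delinearizes to indices = [0, 1]. The input then reads
    // srcIndices = [1, 0], i.e. srcLinearIndices[0] = 1 * 2 + 0 = 2, while
    // the output writes dstIndices = [0, 1], i.e.
    // dstLinearIndex = 0 * 3 + 1 = 1. So output element 1 is produced from
    // input element 2.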

    bool isFloat = elementType.isa<FloatType>();
    if (isFloat) {
      SmallVector<iterator_range<DenseElementsAttr::FloatElementIterator>>
          inputFpIterators;
      for (int i = 0; i < numInputs; ++i)
        inputFpIterators.push_back(inputValues[i].getValues<APFloat>());

      computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));

      // Fold the constants by iterating over every output element. Because we
      // don't know the rank in advance, we need to loop over the range
      // [0, element count) and delinearize each index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i) {
          computeFnInputs.apFloats[i] =
              *(inputFpIterators[i].begin() + srcLinearIndices[i]);
        }

        // Invoke the computation to get the corresponding constant output
        // element.
        APIntOrFloat outputs = computeFn(computeFnInputs);

        fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue();
      }
    } else {
      SmallVector<iterator_range<DenseElementsAttr::IntElementIterator>>
          inputIntIterators;
      for (int i = 0; i < numInputs; ++i)
        inputIntIterators.push_back(inputValues[i].getValues<APInt>());

      computeFnInputs.apInts.resize(numInputs);

      // Fold the constants by iterating over every output element. Because we
      // don't know the rank in advance, we need to loop over the range
      // [0, element count) and delinearize each index.
      for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
        computeRemappedLinearIndex(linearIndex);

        // Collect constant elements for all inputs at this loop iteration.
        for (int i = 0; i < numInputs; ++i) {
          computeFnInputs.apInts[i] =
              *(inputIntIterators[i].begin() + srcLinearIndices[i]);
        }

        // Invoke the computation to get the corresponding constant output
        // element.
        APIntOrFloat outputs = computeFn(computeFnInputs);

        intOutputValues[dstLinearIndex] = outputs.apInt.getValue();
      }
    }

    DenseIntOrFPElementsAttr outputAttr;
    if (isFloat) {
      outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
    } else {
      outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
    }
    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
    return success();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

// Folds linalg.generic ops that are actually transposes on constant values.
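//
// For example (a sketch in IR form; the exact op syntax is abbreviated):
//
//   %cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
//   %0 = linalg.generic {
//          indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
//                           affine_map<(d0, d1) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel"]}
//        ins(%cst : tensor<2x2xi32>) outs(%init : tensor<2x2xi32>) {
//        ^bb0(%arg0: i32, %arg1: i32):
//          linalg.yield %arg0 : i32
//        } -> tensor<2x2xi32>
//
// folds to the transposed constant:
//
//   %0 = arith.constant dense<[[1, 3], [2, 4]]> : tensor<2x2xi32>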
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
  using FoldConstantBase::FoldConstantBase;

  bool matchIndexingMaps(GenericOp genericOp) const {
    // We should have one input and one output.
    return genericOp.getIndexingMaps().size() == 2;
  }

  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
    // Make sure the region only contains a yield op.
    Block &body = genericOp.region().front();
    if (!llvm::hasSingleElement(body))
      return nullptr;
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return nullptr;

    // The yield op should return the block argument corresponding to the
    // input.
    for (Value yieldVal : yieldOp.values()) {
      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return nullptr;
      if (yieldArg.getArgNumber() != 0)
        return nullptr;
    }

    // No computation; just return the original value.
    return [](const APIntOrFloatArray &inputs) {
      if (inputs.apFloats.empty())
        return APIntOrFloat{inputs.apInts.front(), llvm::None};
      return APIntOrFloat{llvm::None, inputs.apFloats.front()};
    };
  }
};

} // namespace

static Optional<SmallVector<Value>>
fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
                   GenericOp producer,
                   const ControlElementwiseOpsFusionFn &controlFn) {
  if (producer->getNumResults() != 1)
    return llvm::None;

  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
                                rewriter);
}

bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
                                      OpOperand &consumer) {
  if (auto producerCollapseOp =
          dyn_cast<linalg::TensorCollapseShapeOp>(producer.getOwner())) {
    return !isUnitDimExpansionOnly(producerCollapseOp);
  }
  if (auto consumerExpandOp =
          dyn_cast<linalg::TensorExpandShapeOp>(consumer.getOwner())) {
    return !isUnitDimExpansionOnly(consumerExpandOp);
  }
  return true;
}
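
// For instance, collapsing tensor<1x4xf32> into tensor<4xf32> only drops a
// unit dimension, so the callback above returns false and fusion with that
// reshape is skipped; collapsing tensor<2x4xf32> into tensor<8xf32> is not
// unit-dim-only and remains eligible for fusion.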

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
      auto producer =
          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
      if (!producer || !producer.hasTensorSemantics())
        continue;
      Optional<SmallVector<Value>> fusedOpResults =
          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
      if (fusedOpResults) {
        rewriter.replaceOp(genericOp, *fusedOpResults);
        return success();
      }
    }
    return failure();
  }

private:
  ControlElementwiseOpsFusionFn controlFn;
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct LinalgElementwiseOpFusionPass
    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    ControlElementwiseOpsFusionFn allowFoldingFn =
        [](const OpResult &producer, const OpOperand &consumer) {
          return true;
        };
    populateElementwiseOpsFusionPatterns(
        patterns,
        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));

    // Use TopDownTraversal for compile-time reasons.
    GreedyRewriteConfig grc;
    grc.useTopDownTraversal = true;
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
                                       grc);
  }
};

/// Pass to test folding of reshape ops with generic ops by linearization.
struct FoldReshapeOpsByLinearizationPass
    : public LinalgFoldReshapeOpsByLinearizationBase<
          FoldReshapeOpsByLinearizationPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
    if (allowFoldingUnitDimReshapes) {
      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
    }
    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
  }
};

/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
/// the value of the `outs` operand is not used within the op. This is only
/// implemented for `linalg.generic` operations for now, but should hold for
/// all linalg structured ops.
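///
/// For example (a sketch): given
///
///   %0 = linalg.generic ... outs(%arg0 : tensor<?xf32>) { ... }
///
/// whose payload never reads the incoming value of %arg0, the `outs` operand
/// is replaced with a fresh `linalg.init_tensor` of the same shape, removing
/// the use-def dependency on %arg0.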
struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand *opOperand : op.getOutputOperands()) {
      if (!op.payloadUsesValueFromOperand(opOperand)) {
        Value operandVal = opOperand->get();
        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
        if (!operandType)
          continue;

        // If outs is already an `init_tensor` operation, nothing to do.
        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
        if (definingOp)
          continue;
        modifiedOutput = true;
        SmallVector<Value> dynamicDims;
        for (auto dim : llvm::enumerate(operandType.getShape())) {
          if (dim.value() != ShapedType::kDynamicSize)
            continue;
          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
              loc, operandVal, dim.index()));
        }
        Value initTensor = rewriter.create<InitTensorOp>(
            loc, dynamicDims, operandType.getShape(),
            operandType.getElementType());
        op->setOperand(opOperand->getOperandNumber(), initTensor);
      }
    }
    if (!modifiedOutput) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};

} // namespace

void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<false, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<false, TensorExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldProducerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldProducerReshapeOpByLinearization<true, TensorExpandShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorCollapseShapeOp>,
           FoldConsumerReshapeOpByLinearization<true, TensorExpandShapeOp>>(
          patterns.getContext());
}

void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
    RewritePatternSet &patterns,
    ControlElementwiseOpsFusionFn controlFoldingReshapes) {
  patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(),
                                                    controlFoldingReshapes);
  patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                     controlFoldingReshapes);
}

void mlir::linalg::populateElementwiseOpsFusionPatterns(
    RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
  auto *context = patterns.getContext();
  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
               FoldConstantTranspose>(context,
                                      options.controlElementwiseOpsFusionFn);
  patterns.add<RemoveOutsDependency>(context);
  populateFoldReshapeOpsByExpansionPatterns(patterns,
                                            options.controlFoldingReshapesFn);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  GenericOp::getCanonicalizationPatterns(patterns, context);
  TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(
      patterns);
}
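
// Example client usage (a sketch; `op` stands for the operation a pass runs
// on):
//
//   RewritePatternSet fusionPatterns(op->getContext());
//   populateElementwiseOpsFusionPatterns(fusionPatterns,
//                                        LinalgElementwiseFusionOptions());
//   (void)applyPatternsAndFoldGreedily(op->getRegions(),
//                                      std::move(fusionPatterns));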

void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PushExpandingReshape>(context);
}

std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
  return std::make_unique<LinalgElementwiseOpFusionPass>();
}

std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
  return std::make_unique<FoldReshapeOpsByLinearizationPass>();
}