[mlir] Use indices instead of affine maps when composing 2 reshape ops.

https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310

Differential Revision: https://reviews.llvm.org/D105550
This commit is contained in:
Alexander Belyaev 2021-07-07 15:09:59 +02:00
parent 033de11150
commit d659527829
2 changed files with 35 additions and 35 deletions

View File

@ -28,27 +28,22 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
constexpr StringRef getReassociationAttrName();
/// Collapse reassociation maps that are used in pair of reshape ops where one
/// Compose reassociation maps that are used in pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
/// both the producer and consumer are collapsing dimensions or both are
/// expanding dimensions.
///
/// For example,
/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
/// affine_map<(d0, d1, d2) -> (d2)>]
/// producerReassociation = [[0, 1], [2], [3, 4]]
/// consumerReassociation = [[0, 1], [2]]
///
/// is folded into
///
/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
/// TODO: Use reassociation indices instead of affine maps here.
Optional<SmallVector<ReassociationIndices>>
collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
ArrayRef<AffineMap> mapsConsumer,
MLIRContext *context);
/// result = [[0, 1, 2], [3, 4]].
Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
ArrayRef<ReassociationIndices> producerReassociations,
ArrayRef<ReassociationIndices> consumerReassociations,
MLIRContext *context);
/// Return the reassociations maps to use to reshape given the source type and
/// the target type when possible. Return llvm::None when this computation
@ -210,8 +205,8 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
ShapedType resultType = reshapeOp.getResultType();
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
reshapeOp.getReassociationMaps(),
composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
reshapeOp.getReassociationIndices(),
rewriter.getContext());
if (!reassociationIndices)
return failure();

View File

@ -11,6 +11,8 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include <numeric>
using namespace mlir;
constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
@ -145,37 +147,40 @@ ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
return success();
}
Optional<SmallVector<ReassociationIndices>>
mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
ArrayRef<AffineMap> mapsConsumer,
MLIRContext *context) {
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
ArrayRef<ReassociationIndices> producerReassociations,
ArrayRef<ReassociationIndices> consumerReassociations,
MLIRContext *context) {
SmallVector<ReassociationIndices> composedIndices;
// Make the producer the larger sized vector. If they are of same size, the
// resulting reshape is not a supported reshape op.
if (mapsProducer.size() == mapsConsumer.size())
if (producerReassociations.size() == consumerReassociations.size())
return llvm::None;
if (mapsProducer.size() < mapsConsumer.size())
std::swap(mapsProducer, mapsConsumer);
if (producerReassociations.size() < consumerReassociations.size())
std::swap(producerReassociations, consumerReassociations);
// Handle the corner case of the result being a rank 0 shaped type. Return an
// empty reassociation.
if (mapsConsumer.empty())
return SmallVector<ReassociationIndices>{};
if (mapsProducer.size() != mapsConsumer[0].getNumDims())
if (consumerReassociations.empty())
return composedIndices;
size_t consumerDims = std::accumulate(
consumerReassociations.begin(), consumerReassociations.end(), 0,
[](size_t all, ReassociationIndicesRef indices) {
return all + indices.size();
});
if (producerReassociations.size() != consumerDims)
return llvm::None;
unsigned currDim = 0;
SmallVector<ReassociationIndices> reassociationMaps;
for (AffineMap rhs : mapsConsumer) {
for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
ReassociationIndices reassociations;
for (AffineExpr rhsExpr : rhs.getResults()) {
AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
i < e; ++i)
reassociations.push_back(currDim++);
for (int64_t consumerIndex : consumerIndices) {
for (int64_t producerIndex : producerReassociations[consumerIndex])
reassociations.push_back(producerIndex);
}
reassociationMaps.push_back(std::move(reassociations));
composedIndices.push_back(std::move(reassociations));
}
return reassociationMaps;
return composedIndices;
}
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,