forked from OSchip/llvm-project
[mlir] Use indices instead of affine maps when composing 2 reshape ops.
https://llvm.discourse.group/t/rfc-reshape-ops-restructuring/3310 Differential Revision: https://reviews.llvm.org/D105550
This commit is contained in:
parent
033de11150
commit
d659527829
|
@ -28,27 +28,22 @@ using ReassociationExprs = SmallVector<AffineExpr, 2>;
|
|||
/// Attribute name for the ArrayAttr which encodes reassociation indices.
|
||||
constexpr StringRef getReassociationAttrName();
|
||||
|
||||
/// Collapse reassociation maps that are used in pair of reshape ops where one
|
||||
/// Compose reassociation maps that are used in pair of reshape ops where one
|
||||
/// is a producer and other is the consumer. Only valid to use this method when
|
||||
/// both the producer and consumer are collapsing dimensions or both are
|
||||
/// expanding dimensions.
|
||||
///
|
||||
/// For example,
|
||||
/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
||||
/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d2)>]
|
||||
/// producerReassociation = [[0, 1], [2], [3, 4]]
|
||||
/// consumerReassociation = [[0, 1], [2]]
|
||||
///
|
||||
/// is folded into
|
||||
///
|
||||
/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
||||
/// TODO: Use reassociation indices instead of affine maps here.
|
||||
Optional<SmallVector<ReassociationIndices>>
|
||||
collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
|
||||
ArrayRef<AffineMap> mapsConsumer,
|
||||
MLIRContext *context);
|
||||
/// result = [[0, 1, 2], [3, 4]].
|
||||
Optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
|
||||
ArrayRef<ReassociationIndices> producerReassociations,
|
||||
ArrayRef<ReassociationIndices> consumerReassociations,
|
||||
MLIRContext *context);
|
||||
|
||||
/// Return the reassociations maps to use to reshape given the source type and
|
||||
/// the target type when possible. Return llvm::None when this computation
|
||||
|
@ -210,8 +205,8 @@ struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
|
|||
|
||||
ShapedType resultType = reshapeOp.getResultType();
|
||||
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
|
||||
collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
|
||||
reshapeOp.getReassociationMaps(),
|
||||
composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
|
||||
reshapeOp.getReassociationIndices(),
|
||||
rewriter.getContext());
|
||||
if (!reassociationIndices)
|
||||
return failure();
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
|
||||
|
@ -145,37 +147,40 @@ ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
Optional<SmallVector<ReassociationIndices>>
|
||||
mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
|
||||
ArrayRef<AffineMap> mapsConsumer,
|
||||
MLIRContext *context) {
|
||||
Optional<SmallVector<ReassociationIndices>> mlir::composeReassociationIndices(
|
||||
ArrayRef<ReassociationIndices> producerReassociations,
|
||||
ArrayRef<ReassociationIndices> consumerReassociations,
|
||||
MLIRContext *context) {
|
||||
SmallVector<ReassociationIndices> composedIndices;
|
||||
// Make the producer the larger sized vector. If they are of same size, the
|
||||
// resulting reshape is not a supported reshape op.
|
||||
if (mapsProducer.size() == mapsConsumer.size())
|
||||
if (producerReassociations.size() == consumerReassociations.size())
|
||||
return llvm::None;
|
||||
if (mapsProducer.size() < mapsConsumer.size())
|
||||
std::swap(mapsProducer, mapsConsumer);
|
||||
if (producerReassociations.size() < consumerReassociations.size())
|
||||
std::swap(producerReassociations, consumerReassociations);
|
||||
|
||||
// Handle the corner case of the result being a rank 0 shaped type. Return an
|
||||
// empty reassociation.
|
||||
if (mapsConsumer.empty())
|
||||
return SmallVector<ReassociationIndices>{};
|
||||
if (mapsProducer.size() != mapsConsumer[0].getNumDims())
|
||||
if (consumerReassociations.empty())
|
||||
return composedIndices;
|
||||
|
||||
size_t consumerDims = std::accumulate(
|
||||
consumerReassociations.begin(), consumerReassociations.end(), 0,
|
||||
[](size_t all, ReassociationIndicesRef indices) {
|
||||
return all + indices.size();
|
||||
});
|
||||
if (producerReassociations.size() != consumerDims)
|
||||
return llvm::None;
|
||||
|
||||
unsigned currDim = 0;
|
||||
SmallVector<ReassociationIndices> reassociationMaps;
|
||||
for (AffineMap rhs : mapsConsumer) {
|
||||
for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
|
||||
ReassociationIndices reassociations;
|
||||
for (AffineExpr rhsExpr : rhs.getResults()) {
|
||||
AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
|
||||
for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
|
||||
i < e; ++i)
|
||||
reassociations.push_back(currDim++);
|
||||
for (int64_t consumerIndex : consumerIndices) {
|
||||
for (int64_t producerIndex : producerReassociations[consumerIndex])
|
||||
reassociations.push_back(producerIndex);
|
||||
}
|
||||
reassociationMaps.push_back(std::move(reassociations));
|
||||
composedIndices.push_back(std::move(reassociations));
|
||||
}
|
||||
return reassociationMaps;
|
||||
return composedIndices;
|
||||
}
|
||||
|
||||
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
|
||||
|
|
Loading…
Reference in New Issue