forked from OSchip/llvm-project
[mlir] Move common reshapeops-related code to ReshapeOpsUtils.h.
This is a first step to move (Tensor)Expand/CollapseShapeOp to tensor/memref dialects. Differential Revision: https://reviews.llvm.org/D105547
This commit is contained in:
parent
d0b282e10b
commit
6412a13539
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
||||||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
|
@ -52,16 +53,6 @@ using LoopRangeBuilder =
|
||||||
/// provide an op-specified hook so that Linalg ops may override the behavior.
|
/// provide an op-specified hook so that Linalg ops may override the behavior.
|
||||||
LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
|
LoopRangeBuilder defaultLoopRangesBuilder(LinalgOp op);
|
||||||
|
|
||||||
using ReassociationIndices = SmallVector<int64_t, 2>;
|
|
||||||
using ReassociationIndicesRef = ArrayRef<int64_t>;
|
|
||||||
using ReassociationExprs = SmallVector<AffineExpr, 2>;
|
|
||||||
|
|
||||||
/// Return the reassociations maps to use to reshape given the source type and
|
|
||||||
/// the target type when possible. Return llvm::None when this computation
|
|
||||||
/// failed.
|
|
||||||
Optional<SmallVector<ReassociationIndices>>
|
|
||||||
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
|
|
||||||
|
|
||||||
/// Returns the name mangled library call name to disambiguate between different
|
/// Returns the name mangled library call name to disambiguate between different
|
||||||
/// overloads at the C level. The name mangling scheme is basic and uses MLIR
|
/// overloads at the C level. The name mangling scheme is basic and uses MLIR
|
||||||
/// type names:
|
/// type names:
|
||||||
|
|
|
@ -0,0 +1,266 @@
|
||||||
|
//===- RehshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This header file defines utilities and common canonicalization patterns for
|
||||||
|
// reshape operations.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
|
||||||
|
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
|
||||||
|
|
||||||
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
using ReassociationIndices = SmallVector<int64_t, 2>;
|
||||||
|
using ReassociationIndicesRef = ArrayRef<int64_t>;
|
||||||
|
using ReassociationExprs = SmallVector<AffineExpr, 2>;
|
||||||
|
|
||||||
|
/// Attribute name for the ArrayAttr which encodes reassociation indices.
|
||||||
|
constexpr StringRef getReassociationAttrName();
|
||||||
|
|
||||||
|
/// Collapse reassociation maps that are used in pair of reshape ops where one
|
||||||
|
/// is a producer and other is the consumer. Only valid to use this method when
|
||||||
|
/// both the producer and consumer are collapsing dimensions or both are
|
||||||
|
/// expanding dimensions.
|
||||||
|
///
|
||||||
|
/// For example,
|
||||||
|
/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
||||||
|
/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
||||||
|
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
||||||
|
/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
|
||||||
|
/// affine_map<(d0, d1, d2) -> (d2)>]
|
||||||
|
///
|
||||||
|
/// is folded into
|
||||||
|
///
|
||||||
|
/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
|
||||||
|
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
||||||
|
/// TODO: Use reassociation indices instead of affine maps here.
|
||||||
|
Optional<SmallVector<ReassociationIndices>>
|
||||||
|
collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
|
||||||
|
ArrayRef<AffineMap> mapsConsumer,
|
||||||
|
MLIRContext *context);
|
||||||
|
|
||||||
|
/// Return the reassociations maps to use to reshape given the source type and
|
||||||
|
/// the target type when possible. Return llvm::None when this computation
|
||||||
|
/// failed.
|
||||||
|
Optional<SmallVector<ReassociationIndices>>
|
||||||
|
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
|
||||||
|
|
||||||
|
/// Return true if the reassociation specification is valid, false otherwise.
|
||||||
|
/// When false, the `invalidIndex` integer pointer is optionally filled with the
|
||||||
|
/// index of the offending reassociation map.
|
||||||
|
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
|
||||||
|
int *invalidIndex = nullptr);
|
||||||
|
|
||||||
|
/// Parse a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
|
||||||
|
/// linalg::(Tensor)CollapseShapeOp.
|
||||||
|
ParseResult parseReshapeLikeOp(OpAsmParser &parser, OperationState &result);
|
||||||
|
|
||||||
|
/// Print a reshape-like op, i.e. linalg::(Tensor)ExpandShapeOp,
|
||||||
|
/// linalg::(Tensor)CollapseShapeOp.
|
||||||
|
template <typename ReshapeLikeOp>
|
||||||
|
void printReshapeOp(OpAsmPrinter &p, ReshapeLikeOp op) {
|
||||||
|
p << op.getOperationName() << ' ' << op.src() << " [";
|
||||||
|
|
||||||
|
llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
|
||||||
|
p << '[';
|
||||||
|
auto arrayAttr = attr.template cast<ArrayAttr>();
|
||||||
|
llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
|
||||||
|
p << attr.cast<IntegerAttr>().getInt();
|
||||||
|
});
|
||||||
|
p << ']';
|
||||||
|
});
|
||||||
|
|
||||||
|
p << "] ";
|
||||||
|
p.printOptionalAttrDict(op->getAttrs(),
|
||||||
|
/*elidedAttrs=*/{op.getReassociationAttrName()});
|
||||||
|
p << ": " << op.src().getType() << " into " << op.getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
|
||||||
|
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
|
||||||
|
ArrayRef<Attribute> operands) {
|
||||||
|
// Fold producer-consumer reshape ops that where the operand type of the
|
||||||
|
// producer is same as the return type of the consumer.
|
||||||
|
auto reshapeSrcOp =
|
||||||
|
reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
|
||||||
|
if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
|
||||||
|
return reshapeSrcOp.src();
|
||||||
|
// Reshape of a constant can be replaced with a new constant.
|
||||||
|
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
|
||||||
|
return elements.reshape(
|
||||||
|
reshapeOp.getResult().getType().template cast<ShapedType>());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Common verifier for reshape-like types. Fills `expandedType` and
|
||||||
|
///`collapsedType` with the proper `src` or `result` type.
|
||||||
|
template <typename Op, typename T>
|
||||||
|
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
|
||||||
|
T collapsedType, bool isExpansion) {
|
||||||
|
unsigned expandedRank = expandedType.getRank();
|
||||||
|
unsigned collapsedRank = collapsedType.getRank();
|
||||||
|
if (expandedRank < collapsedRank)
|
||||||
|
return op.emitOpError("expected the type ")
|
||||||
|
<< expandedType
|
||||||
|
<< " to have higher rank than the type = " << collapsedType;
|
||||||
|
if (expandedRank == 0)
|
||||||
|
return op.emitOpError("expected non-zero memref ranks");
|
||||||
|
if (expandedRank == collapsedRank)
|
||||||
|
return op.emitOpError("expected to collapse or expand dims");
|
||||||
|
|
||||||
|
if (collapsedRank == 0) {
|
||||||
|
// If collapsed rank is 0, then expanded type must be static shaped and of
|
||||||
|
// sizes 1.
|
||||||
|
if (llvm::any_of(expandedType.getShape(),
|
||||||
|
[](int64_t dim) -> bool { return dim != 1; }))
|
||||||
|
return op.emitOpError("invalid to reshape tensor/memref with non-unit "
|
||||||
|
"extent dimensions to zero-rank tensor/memref");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
if (collapsedRank != op.reassociation().size())
|
||||||
|
return op.emitOpError("expected rank of the collapsed type(")
|
||||||
|
<< collapsedRank << ") to be the number of reassociation maps("
|
||||||
|
<< op.reassociation().size() << ")";
|
||||||
|
auto maps = op.getReassociationMaps();
|
||||||
|
for (auto it : llvm::enumerate(maps))
|
||||||
|
if (it.value().getNumDims() != expandedRank)
|
||||||
|
return op.emitOpError("expected reassociation map #")
|
||||||
|
<< it.index() << " of same rank as expanded memref("
|
||||||
|
<< expandedRank << "), but got " << it.value().getNumDims();
|
||||||
|
int invalidIdx = 0;
|
||||||
|
if (!isReassociationValid(maps, &invalidIdx))
|
||||||
|
return op.emitOpError("expected reassociation map #")
|
||||||
|
<< invalidIdx << " to be valid and contiguous";
|
||||||
|
return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Verify that shapes of the reshaped types using following rules
|
||||||
|
/// 1) if a dimension in the collapsed type is static, then the corresponding
|
||||||
|
/// dimensions in the expanded shape should be
|
||||||
|
/// a) static
|
||||||
|
/// b) the product should be same as the collaped shape.
|
||||||
|
/// 2) if a dimension in the collaped type is dynamic, one and only one of the
|
||||||
|
/// corresponding dimensions in the expanded type should be dynamic. This
|
||||||
|
/// rule is only needed with reshape operations that are expanding.
|
||||||
|
template <typename OpTy>
|
||||||
|
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
|
||||||
|
ShapedType expandedType,
|
||||||
|
bool isExpandingReshape) {
|
||||||
|
ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
|
||||||
|
ArrayRef<int64_t> expandedShape = expandedType.getShape();
|
||||||
|
unsigned expandedDimStart = 0;
|
||||||
|
for (auto map : llvm::enumerate(op.getReassociationMaps())) {
|
||||||
|
Optional<int64_t> dynamicShape;
|
||||||
|
int64_t linearizedStaticShape = 1;
|
||||||
|
for (auto dim : llvm::enumerate(expandedShape.slice(
|
||||||
|
expandedDimStart, map.value().getNumResults()))) {
|
||||||
|
if (ShapedType::isDynamic(dim.value())) {
|
||||||
|
if (isExpandingReshape && dynamicShape) {
|
||||||
|
return op->emitOpError("invalid to have a single dimension (")
|
||||||
|
<< map.index() << ") expanded into multiple dynamic dims ("
|
||||||
|
<< expandedDimStart + dynamicShape.getValue() << ","
|
||||||
|
<< expandedDimStart + dim.index() << ")";
|
||||||
|
}
|
||||||
|
dynamicShape = dim.index();
|
||||||
|
} else {
|
||||||
|
linearizedStaticShape *= dim.value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (dynamicShape) {
|
||||||
|
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
|
||||||
|
return op->emitOpError("expected dimension ")
|
||||||
|
<< map.index()
|
||||||
|
<< " of collapsed type to be dynamic since one or more of the "
|
||||||
|
"corresponding dimensions in the expanded type is dynamic";
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (collapsedShape[map.index()] != linearizedStaticShape) {
|
||||||
|
return op->emitOpError("expected dimension ")
|
||||||
|
<< map.index() << " of collapsed type to be static value of "
|
||||||
|
<< linearizedStaticShape << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
expandedDimStart += map.value().getNumResults();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
|
||||||
|
/// dimensions or are both expanding dimensions.
|
||||||
|
template <typename ReshapeOpTy>
|
||||||
|
struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
|
||||||
|
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
|
||||||
|
if (!srcReshapeOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ShapedType resultType = reshapeOp.getResultType();
|
||||||
|
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
|
||||||
|
collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
|
||||||
|
reshapeOp.getReassociationMaps(),
|
||||||
|
rewriter.getContext());
|
||||||
|
if (!reassociationIndices)
|
||||||
|
return failure();
|
||||||
|
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
|
||||||
|
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
|
||||||
|
/// dimensions or are both expanding dimensions.
|
||||||
|
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
|
||||||
|
struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
|
||||||
|
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto srcReshapeOp =
|
||||||
|
reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
|
||||||
|
if (!srcReshapeOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
|
||||||
|
ShapedType intermediateType = reshapeOp.getSrcType();
|
||||||
|
ShapedType resultType = reshapeOp.getResultType();
|
||||||
|
|
||||||
|
// If the source reshape can be collapsed/expanded into the target reshape
|
||||||
|
// they can still be folded. This can only be reasoned about statically
|
||||||
|
// for cases where
|
||||||
|
// - either all shapes are static, or
|
||||||
|
// - The number of dynamic dimensions matches in the source of source and
|
||||||
|
// result with all other dimensions being 1.
|
||||||
|
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
|
||||||
|
getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
|
||||||
|
if (!reassociationIndices)
|
||||||
|
return failure();
|
||||||
|
bool originalOpExpands =
|
||||||
|
intermediateType.getRank() > srcReshapeSrcType.getRank();
|
||||||
|
bool resultingOpExpands =
|
||||||
|
resultType.getRank() > srcReshapeSrcType.getRank();
|
||||||
|
if (!(resultingOpExpands ^ originalOpExpands))
|
||||||
|
rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
|
||||||
|
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
|
||||||
|
else
|
||||||
|
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
|
||||||
|
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
|
|
@ -10,6 +10,7 @@ add_mlir_conversion_library(MLIRTosaToLinalg
|
||||||
MLIRConversionPassIncGen
|
MLIRConversionPassIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRDialectUtils
|
||||||
MLIRIR
|
MLIRIR
|
||||||
MLIRLinalg
|
MLIRLinalg
|
||||||
MLIRLinalgUtils
|
MLIRLinalgUtils
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
|
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
@ -1120,8 +1121,7 @@ public:
|
||||||
(operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
|
(operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
|
||||||
: operandTy.getShape());
|
: operandTy.getShape());
|
||||||
unsigned currSrcDim = 0, currDstDim = 0;
|
unsigned currSrcDim = 0, currDstDim = 0;
|
||||||
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
|
SmallVector<ReassociationExprs, 4> reassociationMap(collapsedShape.size());
|
||||||
collapsedShape.size());
|
|
||||||
|
|
||||||
// First scan all dimensions in the source shapes to see whether we have a
|
// First scan all dimensions in the source shapes to see whether we have a
|
||||||
// perfect case where consecutive dimensions in source are collapsed. For
|
// perfect case where consecutive dimensions in source are collapsed. For
|
||||||
|
@ -1176,11 +1176,11 @@ public:
|
||||||
std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
|
std::accumulate(expandedShape.begin(), expandedShape.end(), 1,
|
||||||
std::multiplies<int64_t>());
|
std::multiplies<int64_t>());
|
||||||
auto elemTy = operandTy.getElementType();
|
auto elemTy = operandTy.getElementType();
|
||||||
SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
|
SmallVector<ReassociationExprs, 4> collapsingMap = {
|
||||||
// Use operandTy here because we need to collapse all operands
|
// Use operandTy here because we need to collapse all operands
|
||||||
// dimensions.
|
// dimensions.
|
||||||
getIdentityExprs(operandTy.getShape().size())};
|
getIdentityExprs(operandTy.getShape().size())};
|
||||||
SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
|
SmallVector<ReassociationExprs, 4> expandingMap = {
|
||||||
// Use resultTy here because we need to expand to all result
|
// Use resultTy here because we need to expand to all result
|
||||||
// dimensions.
|
// dimensions.
|
||||||
getIdentityExprs(resultTy.getShape().size())};
|
getIdentityExprs(resultTy.getShape().size())};
|
||||||
|
|
|
@ -1069,338 +1069,20 @@ OpFoldResult PadTensorOp::fold(ArrayRef<Attribute>) {
|
||||||
// ReshapeOp
|
// ReshapeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Optional<SmallVector<ReassociationIndices>>
|
|
||||||
mlir::linalg::getReassociationIndicesForReshape(ShapedType sourceType,
|
|
||||||
ShapedType targetType) {
|
|
||||||
// Make the sourceType greater rank than the targetType. If they are same
|
|
||||||
// rank, then its an unsupported reshape op.
|
|
||||||
if (sourceType.getRank() == targetType.getRank())
|
|
||||||
return llvm::None;
|
|
||||||
if (sourceType.getRank() < targetType.getRank())
|
|
||||||
std::swap(sourceType, targetType);
|
|
||||||
|
|
||||||
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
|
||||||
ArrayRef<int64_t> targetShape = targetType.getShape();
|
|
||||||
unsigned sourceDim = 0;
|
|
||||||
SmallVector<ReassociationIndices> reassociationMap;
|
|
||||||
reassociationMap.reserve(targetType.getRank());
|
|
||||||
|
|
||||||
ReassociationIndices currIndices;
|
|
||||||
int64_t prodOfCollapsedDims = 1;
|
|
||||||
while (sourceDim < sourceShape.size()) {
|
|
||||||
unsigned targetDim = reassociationMap.size();
|
|
||||||
|
|
||||||
// If all the dimensions of the targetShape are exhausted, then the
|
|
||||||
// remaining dims in the source shape must be all 1s. So for such cases, set
|
|
||||||
// 1 as the target shape. The actual reassociation indices will be handled
|
|
||||||
// later.
|
|
||||||
int64_t currTargetShape =
|
|
||||||
(targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
|
|
||||||
while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
|
|
||||||
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
|
|
||||||
sourceDim < sourceShape.size()) {
|
|
||||||
prodOfCollapsedDims *= sourceShape[sourceDim];
|
|
||||||
currIndices.push_back(sourceDim++);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the current expanded dimension is dynamic, then the collapsed
|
|
||||||
// dimensions should also be dynamic and product of all previous unprocessed
|
|
||||||
// dimensions of the expanded shape should be 1.
|
|
||||||
if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
|
|
||||||
(currTargetShape != ShapedType::kDynamicSize ||
|
|
||||||
prodOfCollapsedDims != 1))
|
|
||||||
return llvm::None;
|
|
||||||
|
|
||||||
// If the collapsed dim is dynamic, the current expanded dim should also
|
|
||||||
// be dynamic.
|
|
||||||
if (currTargetShape == ShapedType::kDynamicSize &&
|
|
||||||
sourceShape[sourceDim] != ShapedType::kDynamicSize)
|
|
||||||
return llvm::None;
|
|
||||||
|
|
||||||
// For static shapes, if the product of dimensions of the expanded shape
|
|
||||||
// should match the collapsed dimension shape.
|
|
||||||
if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
|
|
||||||
return llvm::None;
|
|
||||||
|
|
||||||
currIndices.push_back(sourceDim++);
|
|
||||||
// If the reassociation is empty but the currIndices is not, this by
|
|
||||||
// definition is folding unit-dimensions with the result being scalar type.
|
|
||||||
// So only append the `currIndices` if reassociation map is not empty.
|
|
||||||
if (targetDim == targetShape.size()) {
|
|
||||||
if (!reassociationMap.empty() && !currIndices.empty())
|
|
||||||
reassociationMap.back().append(currIndices.begin(), currIndices.end());
|
|
||||||
// Break out of the loops. We should be done here.
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
reassociationMap.emplace_back(ReassociationIndices{});
|
|
||||||
std::swap(reassociationMap.back(), currIndices);
|
|
||||||
prodOfCollapsedDims = 1;
|
|
||||||
}
|
|
||||||
// All the dimensions in the two shapes must have been processed.
|
|
||||||
if (reassociationMap.size() != targetShape.size() ||
|
|
||||||
sourceDim != sourceShape.size())
|
|
||||||
return llvm::None;
|
|
||||||
return reassociationMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename ReshapeLikeOp>
|
|
||||||
static void print(OpAsmPrinter &p, ReshapeLikeOp op) {
|
|
||||||
p << op.getOperationName() << ' ' << op.src() << " [";
|
|
||||||
|
|
||||||
llvm::interleaveComma(op.reassociation(), p, [&](const Attribute &attr) {
|
|
||||||
p << '[';
|
|
||||||
auto arrayAttr = attr.template cast<ArrayAttr>();
|
|
||||||
llvm::interleaveComma(arrayAttr, p, [&](const Attribute &attr) {
|
|
||||||
p << attr.cast<IntegerAttr>().getInt();
|
|
||||||
});
|
|
||||||
p << ']';
|
|
||||||
});
|
|
||||||
|
|
||||||
p << "] ";
|
|
||||||
p.printOptionalAttrDict(op->getAttrs(),
|
|
||||||
/*elidedAttrs=*/{op.getReassociationAttrName()});
|
|
||||||
p << ": " << op.src().getType() << " into " << op.getType();
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
|
static void print(OpAsmPrinter &p, linalg::ExpandShapeOp op) {
|
||||||
print<linalg::ExpandShapeOp>(p, op);
|
::mlir::printReshapeOp<linalg::ExpandShapeOp>(p, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
|
static void print(OpAsmPrinter &p, linalg::CollapseShapeOp op) {
|
||||||
print<linalg::CollapseShapeOp>(p, op);
|
::mlir::printReshapeOp<linalg::CollapseShapeOp>(p, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
|
static void print(OpAsmPrinter &p, linalg::TensorExpandShapeOp op) {
|
||||||
print<linalg::TensorExpandShapeOp>(p, op);
|
::mlir::printReshapeOp<linalg::TensorExpandShapeOp>(p, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
|
static void print(OpAsmPrinter &p, linalg::TensorCollapseShapeOp op) {
|
||||||
print<linalg::TensorCollapseShapeOp>(p, op);
|
::mlir::printReshapeOp<linalg::TensorCollapseShapeOp>(p, op);
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr StringRef getReassociationAttrName() {
|
|
||||||
return "reassociation";
|
|
||||||
}
|
|
||||||
|
|
||||||
static ParseResult parseReshapeLikeOp(OpAsmParser &parser,
|
|
||||||
OperationState &result) {
|
|
||||||
// Parse the operand.
|
|
||||||
OpAsmParser::OperandType src;
|
|
||||||
if (parser.parseOperand(src))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Parse reassociation indices.
|
|
||||||
Builder &b = parser.getBuilder();
|
|
||||||
SmallVector<Attribute, 4> reassociation;
|
|
||||||
if (parser.parseLSquare())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (succeeded(parser.parseOptionalRSquare()))
|
|
||||||
break;
|
|
||||||
if (parser.parseLSquare())
|
|
||||||
return failure();
|
|
||||||
SmallVector<int64_t> indices;
|
|
||||||
while (true) {
|
|
||||||
int64_t index;
|
|
||||||
if (parser.parseInteger(index))
|
|
||||||
return failure();
|
|
||||||
indices.push_back(index);
|
|
||||||
|
|
||||||
if (succeeded(parser.parseOptionalComma()))
|
|
||||||
continue;
|
|
||||||
if (failed(parser.parseRSquare()))
|
|
||||||
return failure();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
reassociation.push_back(b.getI64ArrayAttr(indices));
|
|
||||||
if (succeeded(parser.parseOptionalComma()))
|
|
||||||
continue;
|
|
||||||
if (failed(parser.parseRSquare()))
|
|
||||||
return failure();
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.addAttribute(getReassociationAttrName(),
|
|
||||||
b.getArrayAttr(reassociation));
|
|
||||||
|
|
||||||
// Parse optional attributes.
|
|
||||||
parser.parseOptionalAttrDict(result.attributes);
|
|
||||||
|
|
||||||
// Parse types.
|
|
||||||
Type srcType;
|
|
||||||
Type resultType;
|
|
||||||
if (parser.parseColon() || parser.parseType(srcType) ||
|
|
||||||
parser.resolveOperand(src, srcType, result.operands) ||
|
|
||||||
parser.parseKeyword("into") || parser.parseType(resultType))
|
|
||||||
return failure();
|
|
||||||
result.addTypes(resultType);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Collapse reassociation maps that are used in pair of reshape ops where one
|
|
||||||
/// is a producer and other is the consumer. Only valid to use this method when
|
|
||||||
/// both the producer and consumer are collapsing dimensions or both are
|
|
||||||
/// expanding dimensions.
|
|
||||||
///
|
|
||||||
/// For example,
|
|
||||||
/// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
|
|
||||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d2)>,
|
|
||||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
|
||||||
/// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>,
|
|
||||||
/// affine_map<(d0, d1, d2) -> (d2)>]
|
|
||||||
///
|
|
||||||
/// is folded into
|
|
||||||
///
|
|
||||||
/// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
|
|
||||||
/// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>]
|
|
||||||
static Optional<SmallVector<ReassociationIndices>>
|
|
||||||
collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
|
|
||||||
ArrayRef<AffineMap> mapsConsumer,
|
|
||||||
MLIRContext *context) {
|
|
||||||
// Make the producer the larger sized vector. If they are of same size, the
|
|
||||||
// resulting reshape is not a supported reshape op.
|
|
||||||
if (mapsProducer.size() == mapsConsumer.size())
|
|
||||||
return llvm::None;
|
|
||||||
if (mapsProducer.size() < mapsConsumer.size())
|
|
||||||
std::swap(mapsProducer, mapsConsumer);
|
|
||||||
|
|
||||||
// Handle the corner case of the result being a rank 0 shaped type. Return an
|
|
||||||
// empty reassociation.
|
|
||||||
if (mapsConsumer.empty())
|
|
||||||
return SmallVector<ReassociationIndices>{};
|
|
||||||
if (mapsProducer.size() != mapsConsumer[0].getNumDims())
|
|
||||||
return llvm::None;
|
|
||||||
|
|
||||||
unsigned currDim = 0;
|
|
||||||
SmallVector<ReassociationIndices> reassociationMaps;
|
|
||||||
for (AffineMap rhs : mapsConsumer) {
|
|
||||||
ReassociationIndices reassociations;
|
|
||||||
for (AffineExpr rhsExpr : rhs.getResults()) {
|
|
||||||
AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
|
|
||||||
for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
|
|
||||||
i < e; ++i)
|
|
||||||
reassociations.push_back(currDim++);
|
|
||||||
}
|
|
||||||
reassociationMaps.push_back(std::move(reassociations));
|
|
||||||
}
|
|
||||||
return reassociationMaps;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
|
|
||||||
/// dimensions or are both expanding dimensions.
|
|
||||||
template <typename ReshapeOpTy>
|
|
||||||
struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> {
|
|
||||||
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>();
|
|
||||||
if (!srcReshapeOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
ShapedType resultType = reshapeOp.getResultType();
|
|
||||||
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
|
|
||||||
collapseReassociationIndices(srcReshapeOp.getReassociationMaps(),
|
|
||||||
reshapeOp.getReassociationMaps(),
|
|
||||||
rewriter.getContext());
|
|
||||||
if (!reassociationIndices)
|
|
||||||
return failure();
|
|
||||||
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
|
|
||||||
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
|
|
||||||
/// dimensions or are both expanding dimensions.
|
|
||||||
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
|
|
||||||
struct CollapseMixedReshapeOps : public OpRewritePattern<ReshapeOpTy> {
|
|
||||||
using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
|
|
||||||
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto srcReshapeOp =
|
|
||||||
reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
|
|
||||||
if (!srcReshapeOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
ShapedType srcReshapeSrcType = srcReshapeOp.getSrcType();
|
|
||||||
ShapedType intermediateType = reshapeOp.getSrcType();
|
|
||||||
ShapedType resultType = reshapeOp.getResultType();
|
|
||||||
|
|
||||||
// If the source reshape can be collapsed/expanded into the target reshape
|
|
||||||
// they can still be folded. This can only be reasoned about statically
|
|
||||||
// for cases where
|
|
||||||
// - either all shapes are static, or
|
|
||||||
// - The number of dynamic dimensions matches in the source of source and
|
|
||||||
// result with all other dimensions being 1.
|
|
||||||
Optional<SmallVector<ReassociationIndices>> reassociationIndices =
|
|
||||||
getReassociationIndicesForReshape(srcReshapeSrcType, resultType);
|
|
||||||
if (!reassociationIndices)
|
|
||||||
return failure();
|
|
||||||
bool originalOpExpands =
|
|
||||||
intermediateType.getRank() > srcReshapeSrcType.getRank();
|
|
||||||
bool resultingOpExpands =
|
|
||||||
resultType.getRank() > srcReshapeSrcType.getRank();
|
|
||||||
if (!(resultingOpExpands ^ originalOpExpands))
|
|
||||||
rewriter.replaceOpWithNewOp<InverseReshapeOpTy>(
|
|
||||||
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
|
|
||||||
else
|
|
||||||
rewriter.replaceOpWithNewOp<ReshapeOpTy>(
|
|
||||||
reshapeOp, resultType, srcReshapeOp.src(), *reassociationIndices);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
|
|
||||||
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
|
|
||||||
ArrayRef<Attribute> operands) {
|
|
||||||
// Fold producer-consumer reshape ops that where the operand type of the
|
|
||||||
// producer is same as the return type of the consumer.
|
|
||||||
auto reshapeSrcOp =
|
|
||||||
reshapeOp.src().template getDefiningOp<InverseReshapeOpTy>();
|
|
||||||
if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
|
|
||||||
return reshapeSrcOp.src();
|
|
||||||
// Reshape of a constant can be replaced with a new constant.
|
|
||||||
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
|
|
||||||
return elements.reshape(
|
|
||||||
reshapeOp.getResult().getType().template cast<ShapedType>());
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return true if the reassociation specification is valid, false otherwise.
|
|
||||||
/// When false, the `invalidIndex` integer pointer is optionally filled with the
|
|
||||||
/// index of the offending reassociation map.
|
|
||||||
static bool isReassociationValid(ArrayRef<AffineMap> reassociation,
|
|
||||||
int *invalidIndex = nullptr) {
|
|
||||||
if (reassociation.empty())
|
|
||||||
return true;
|
|
||||||
unsigned nDims = reassociation[0].getNumDims();
|
|
||||||
unsigned nextExpectedDim = 0;
|
|
||||||
for (auto it : llvm::enumerate(reassociation)) {
|
|
||||||
auto m = it.value();
|
|
||||||
if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
|
|
||||||
if (invalidIndex)
|
|
||||||
*invalidIndex = it.index();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
for (auto e : m.getResults()) {
|
|
||||||
auto d = e.dyn_cast<AffineDimExpr>();
|
|
||||||
if (!d || d.getPosition() != nextExpectedDim++) {
|
|
||||||
if (invalidIndex)
|
|
||||||
*invalidIndex = it.index();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (nextExpectedDim != nDims) {
|
|
||||||
if (invalidIndex)
|
|
||||||
*invalidIndex = reassociation.size() - 1;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
|
/// Detect whether memref dims [dim, dim + extent) can be reshaped without
|
||||||
|
@ -1736,106 +1418,12 @@ void mlir::linalg::CollapseShapeOp::build(
|
||||||
|
|
||||||
Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
|
Value mlir::linalg::CollapseShapeOp::getViewSource() { return src(); }
|
||||||
|
|
||||||
/// Verify that shapes of the reshaped types using following rules
|
template <typename ReshapeOp,
|
||||||
/// 1) if a dimension in the collapsed type is static, then the corresponding
|
bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
|
||||||
/// dimensions in the expanded shape should be
|
static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
|
||||||
/// a) static
|
|
||||||
/// b) the product should be same as the collaped shape.
|
|
||||||
/// 2) if a dimension in the collaped type is dynamic, one and only one of the
|
|
||||||
/// corresponding dimensions in the expanded type should be dynamic. This
|
|
||||||
/// rule is only needed with reshape operations that are expanding.
|
|
||||||
template <typename OpTy>
|
|
||||||
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
|
|
||||||
ShapedType expandedType,
|
|
||||||
bool isExpandingReshape) {
|
|
||||||
ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
|
|
||||||
ArrayRef<int64_t> expandedShape = expandedType.getShape();
|
|
||||||
unsigned expandedDimStart = 0;
|
|
||||||
for (auto map : llvm::enumerate(op.getReassociationMaps())) {
|
|
||||||
Optional<int64_t> dynamicShape;
|
|
||||||
int64_t linearizedStaticShape = 1;
|
|
||||||
for (auto dim : llvm::enumerate(expandedShape.slice(
|
|
||||||
expandedDimStart, map.value().getNumResults()))) {
|
|
||||||
if (ShapedType::isDynamic(dim.value())) {
|
|
||||||
if (isExpandingReshape && dynamicShape) {
|
|
||||||
return op->emitOpError("invalid to have a single dimension (")
|
|
||||||
<< map.index() << ") expanded into multiple dynamic dims ("
|
|
||||||
<< expandedDimStart + dynamicShape.getValue() << ","
|
|
||||||
<< expandedDimStart + dim.index() << ")";
|
|
||||||
}
|
|
||||||
dynamicShape = dim.index();
|
|
||||||
} else {
|
|
||||||
linearizedStaticShape *= dim.value();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (dynamicShape) {
|
|
||||||
if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
|
|
||||||
return op->emitOpError("expected dimension ")
|
|
||||||
<< map.index()
|
|
||||||
<< " of collapsed type to be dynamic since one or more of the "
|
|
||||||
"corresponding dimensions in the expanded type is dynamic";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (collapsedShape[map.index()] != linearizedStaticShape) {
|
|
||||||
return op->emitOpError("expected dimension ")
|
|
||||||
<< map.index() << " of collapsed type to be static value of "
|
|
||||||
<< linearizedStaticShape << " ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expandedDimStart += map.value().getNumResults();
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Common verifier for reshape-like types. Fills `expandedType` and
|
|
||||||
// `collapsedType` with the proper `src` or `result` type.
|
|
||||||
template <typename Op, typename T,
|
|
||||||
bool isExpansion = std::is_same<Op, TensorExpandShapeOp>::value ||
|
|
||||||
std::is_same<Op, ExpandShapeOp>::value>
|
|
||||||
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
|
|
||||||
T collapsedType) {
|
|
||||||
unsigned expandedRank = expandedType.getRank();
|
|
||||||
unsigned collapsedRank = collapsedType.getRank();
|
|
||||||
if (expandedRank < collapsedRank)
|
|
||||||
return op.emitOpError("expected the type ")
|
|
||||||
<< expandedType
|
|
||||||
<< " to have higher rank than the type = " << collapsedType;
|
|
||||||
if (expandedRank == 0)
|
|
||||||
return op.emitOpError("expected non-zero memref ranks");
|
|
||||||
if (expandedRank == collapsedRank)
|
|
||||||
return op.emitOpError("expected to collapse or expand dims");
|
|
||||||
|
|
||||||
if (collapsedRank == 0) {
|
|
||||||
// If collapsed rank is 0, then expanded type must be static shaped and of
|
|
||||||
// sizes 1.
|
|
||||||
if (llvm::any_of(expandedType.getShape(),
|
|
||||||
[](int64_t dim) -> bool { return dim != 1; }))
|
|
||||||
return op.emitOpError("invalid to reshape tensor/memref with non-unit "
|
|
||||||
"extent dimensions to zero-rank tensor/memref");
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
if (collapsedRank != op.reassociation().size())
|
|
||||||
return op.emitOpError("expected rank of the collapsed type(")
|
|
||||||
<< collapsedRank << ") to be the number of reassociation maps("
|
|
||||||
<< op.reassociation().size() << ")";
|
|
||||||
auto maps = op.getReassociationMaps();
|
|
||||||
for (auto it : llvm::enumerate(maps))
|
|
||||||
if (it.value().getNumDims() != expandedRank)
|
|
||||||
return op.emitOpError("expected reassociation map #")
|
|
||||||
<< it.index() << " of same rank as expanded memref("
|
|
||||||
<< expandedRank << "), but got " << it.value().getNumDims();
|
|
||||||
int invalidIdx = 0;
|
|
||||||
if (!isReassociationValid(maps, &invalidIdx))
|
|
||||||
return op.emitOpError("expected reassociation map #")
|
|
||||||
<< invalidIdx << " to be valid and contiguous";
|
|
||||||
return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename TensorReshapeOp>
|
|
||||||
static LogicalResult verifyReshapeOp(TensorReshapeOp op,
|
|
||||||
MemRefType expandedType,
|
|
||||||
MemRefType collapsedType) {
|
MemRefType collapsedType) {
|
||||||
if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
|
if (failed(
|
||||||
|
verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
|
||||||
return failure();
|
return failure();
|
||||||
auto maps = op.getReassociationMaps();
|
auto maps = op.getReassociationMaps();
|
||||||
MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
|
MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
|
||||||
|
@ -1923,11 +1511,14 @@ void mlir::linalg::TensorExpandShapeOp::build(
|
||||||
getReassociationIndicesAttribute(b, reassociation));
|
getReassociationIndicesAttribute(b, reassociation));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorReshapeOp>
|
template <typename TensorReshapeOp,
|
||||||
|
bool isExpansion =
|
||||||
|
std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value>
|
||||||
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
|
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
|
||||||
RankedTensorType expandedType,
|
RankedTensorType expandedType,
|
||||||
RankedTensorType collapsedType) {
|
RankedTensorType collapsedType) {
|
||||||
if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType)))
|
if (failed(
|
||||||
|
verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto maps = op.getReassociationMaps();
|
auto maps = op.getReassociationMaps();
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
add_mlir_library(MLIRDialectUtils
|
add_mlir_library(MLIRDialectUtils
|
||||||
|
ReshapeOpsUtils.cpp
|
||||||
StructuredOpsUtils.cpp
|
StructuredOpsUtils.cpp
|
||||||
StaticValueUtils.cpp
|
StaticValueUtils.cpp
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,209 @@
|
||||||
|
//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
||||||
|
|
||||||
|
#include "mlir/IR/AffineMap.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
constexpr StringRef mlir::getReassociationAttrName() { return "reassociation"; }
|
||||||
|
|
||||||
|
Optional<SmallVector<ReassociationIndices>>
|
||||||
|
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
|
||||||
|
ShapedType targetType) {
|
||||||
|
// Make the sourceType greater rank than the targetType. If they are same
|
||||||
|
// rank, then its an unsupported reshape op.
|
||||||
|
if (sourceType.getRank() == targetType.getRank())
|
||||||
|
return llvm::None;
|
||||||
|
if (sourceType.getRank() < targetType.getRank())
|
||||||
|
std::swap(sourceType, targetType);
|
||||||
|
|
||||||
|
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||||
|
ArrayRef<int64_t> targetShape = targetType.getShape();
|
||||||
|
unsigned sourceDim = 0;
|
||||||
|
SmallVector<ReassociationIndices> reassociationMap;
|
||||||
|
reassociationMap.reserve(targetType.getRank());
|
||||||
|
|
||||||
|
ReassociationIndices currIndices;
|
||||||
|
int64_t prodOfCollapsedDims = 1;
|
||||||
|
while (sourceDim < sourceShape.size()) {
|
||||||
|
unsigned targetDim = reassociationMap.size();
|
||||||
|
|
||||||
|
// If all the dimensions of the targetShape are exhausted, then the
|
||||||
|
// remaining dims in the source shape must be all 1s. So for such cases, set
|
||||||
|
// 1 as the target shape. The actual reassociation indices will be handled
|
||||||
|
// later.
|
||||||
|
int64_t currTargetShape =
|
||||||
|
(targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
|
||||||
|
while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
|
||||||
|
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
|
||||||
|
sourceDim < sourceShape.size()) {
|
||||||
|
prodOfCollapsedDims *= sourceShape[sourceDim];
|
||||||
|
currIndices.push_back(sourceDim++);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the current expanded dimension is dynamic, then the collapsed
|
||||||
|
// dimensions should also be dynamic and product of all previous unprocessed
|
||||||
|
// dimensions of the expanded shape should be 1.
|
||||||
|
if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
|
||||||
|
(currTargetShape != ShapedType::kDynamicSize ||
|
||||||
|
prodOfCollapsedDims != 1))
|
||||||
|
return llvm::None;
|
||||||
|
|
||||||
|
// If the collapsed dim is dynamic, the current expanded dim should also
|
||||||
|
// be dynamic.
|
||||||
|
if (currTargetShape == ShapedType::kDynamicSize &&
|
||||||
|
sourceShape[sourceDim] != ShapedType::kDynamicSize)
|
||||||
|
return llvm::None;
|
||||||
|
|
||||||
|
// For static shapes, if the product of dimensions of the expanded shape
|
||||||
|
// should match the collapsed dimension shape.
|
||||||
|
if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
|
||||||
|
return llvm::None;
|
||||||
|
|
||||||
|
currIndices.push_back(sourceDim++);
|
||||||
|
// If the reassociation is empty but the currIndices is not, this by
|
||||||
|
// definition is folding unit-dimensions with the result being scalar type.
|
||||||
|
// So only append the `currIndices` if reassociation map is not empty.
|
||||||
|
if (targetDim == targetShape.size()) {
|
||||||
|
if (!reassociationMap.empty() && !currIndices.empty())
|
||||||
|
reassociationMap.back().append(currIndices.begin(), currIndices.end());
|
||||||
|
// Break out of the loops. We should be done here.
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
reassociationMap.emplace_back(ReassociationIndices{});
|
||||||
|
std::swap(reassociationMap.back(), currIndices);
|
||||||
|
prodOfCollapsedDims = 1;
|
||||||
|
}
|
||||||
|
// All the dimensions in the two shapes must have been processed.
|
||||||
|
if (reassociationMap.size() != targetShape.size() ||
|
||||||
|
sourceDim != sourceShape.size())
|
||||||
|
return llvm::None;
|
||||||
|
return reassociationMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult mlir::parseReshapeLikeOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
|
// Parse the operand.
|
||||||
|
OpAsmParser::OperandType src;
|
||||||
|
if (parser.parseOperand(src))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Parse reassociation indices.
|
||||||
|
Builder &b = parser.getBuilder();
|
||||||
|
SmallVector<Attribute, 4> reassociation;
|
||||||
|
if (parser.parseLSquare())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (succeeded(parser.parseOptionalRSquare()))
|
||||||
|
break;
|
||||||
|
if (parser.parseLSquare())
|
||||||
|
return failure();
|
||||||
|
SmallVector<int64_t> indices;
|
||||||
|
while (true) {
|
||||||
|
int64_t index;
|
||||||
|
if (parser.parseInteger(index))
|
||||||
|
return failure();
|
||||||
|
indices.push_back(index);
|
||||||
|
|
||||||
|
if (succeeded(parser.parseOptionalComma()))
|
||||||
|
continue;
|
||||||
|
if (failed(parser.parseRSquare()))
|
||||||
|
return failure();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
reassociation.push_back(b.getI64ArrayAttr(indices));
|
||||||
|
if (succeeded(parser.parseOptionalComma()))
|
||||||
|
continue;
|
||||||
|
if (failed(parser.parseRSquare()))
|
||||||
|
return failure();
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.addAttribute(getReassociationAttrName(),
|
||||||
|
b.getArrayAttr(reassociation));
|
||||||
|
|
||||||
|
// Parse optional attributes.
|
||||||
|
parser.parseOptionalAttrDict(result.attributes);
|
||||||
|
|
||||||
|
// Parse types.
|
||||||
|
Type srcType;
|
||||||
|
Type resultType;
|
||||||
|
if (parser.parseColon() || parser.parseType(srcType) ||
|
||||||
|
parser.resolveOperand(src, srcType, result.operands) ||
|
||||||
|
parser.parseKeyword("into") || parser.parseType(resultType))
|
||||||
|
return failure();
|
||||||
|
result.addTypes(resultType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<SmallVector<ReassociationIndices>>
|
||||||
|
mlir::collapseReassociationIndices(ArrayRef<AffineMap> mapsProducer,
|
||||||
|
ArrayRef<AffineMap> mapsConsumer,
|
||||||
|
MLIRContext *context) {
|
||||||
|
// Make the producer the larger sized vector. If they are of same size, the
|
||||||
|
// resulting reshape is not a supported reshape op.
|
||||||
|
if (mapsProducer.size() == mapsConsumer.size())
|
||||||
|
return llvm::None;
|
||||||
|
if (mapsProducer.size() < mapsConsumer.size())
|
||||||
|
std::swap(mapsProducer, mapsConsumer);
|
||||||
|
|
||||||
|
// Handle the corner case of the result being a rank 0 shaped type. Return an
|
||||||
|
// empty reassociation.
|
||||||
|
if (mapsConsumer.empty())
|
||||||
|
return SmallVector<ReassociationIndices>{};
|
||||||
|
if (mapsProducer.size() != mapsConsumer[0].getNumDims())
|
||||||
|
return llvm::None;
|
||||||
|
|
||||||
|
unsigned currDim = 0;
|
||||||
|
SmallVector<ReassociationIndices> reassociationMaps;
|
||||||
|
for (AffineMap rhs : mapsConsumer) {
|
||||||
|
ReassociationIndices reassociations;
|
||||||
|
for (AffineExpr rhsExpr : rhs.getResults()) {
|
||||||
|
AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>();
|
||||||
|
for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults();
|
||||||
|
i < e; ++i)
|
||||||
|
reassociations.push_back(currDim++);
|
||||||
|
}
|
||||||
|
reassociationMaps.push_back(std::move(reassociations));
|
||||||
|
}
|
||||||
|
return reassociationMaps;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
|
||||||
|
int *invalidIndex) {
|
||||||
|
if (reassociation.empty())
|
||||||
|
return true;
|
||||||
|
unsigned nDims = reassociation[0].getNumDims();
|
||||||
|
unsigned nextExpectedDim = 0;
|
||||||
|
for (auto it : llvm::enumerate(reassociation)) {
|
||||||
|
auto m = it.value();
|
||||||
|
if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
|
||||||
|
if (invalidIndex)
|
||||||
|
*invalidIndex = it.index();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (auto e : m.getResults()) {
|
||||||
|
auto d = e.dyn_cast<AffineDimExpr>();
|
||||||
|
if (!d || d.getPosition() != nextExpectedDim++) {
|
||||||
|
if (invalidIndex)
|
||||||
|
*invalidIndex = it.index();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (nextExpectedDim != nDims) {
|
||||||
|
if (invalidIndex)
|
||||||
|
*invalidIndex = reassociation.size() - 1;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
Loading…
Reference in New Issue