forked from OSchip/llvm-project
[mlir][Vector] Add ExtractOp folding when fed by a TransposeOp
TransposeOp are often followed by ExtractOp. In certain cases however, it is unnecessary (and even detrimental) to lower a TransposeOp to either a flat transpose (llvm.matrix intrinsics) or to unrolled scalar insert / extract chains. Providing foldings of ExtractOp mitigates some of the unnecessary complexity. Differential revision: https://reviews.llvm.org/D83487
This commit is contained in:
parent
9fd4b5faac
commit
a490d387e6
|
@ -170,6 +170,10 @@ public:
|
|||
/// `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)`
|
||||
AffineMap compose(AffineMap map);
|
||||
|
||||
/// Applies composition by the dims of `this` to the integer `values` and
|
||||
/// returns the resulting values. `this` must be symbol-less.
|
||||
SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values);
|
||||
|
||||
/// Returns true if the AffineMap represents a subset (i.e. a projection) of a
|
||||
/// symbol-less permutation map.
|
||||
bool isProjectedPermutation();
|
||||
|
@ -180,6 +184,11 @@ public:
|
|||
/// Returns the map consisting of the `resultPos` subset.
|
||||
AffineMap getSubMap(ArrayRef<unsigned> resultPos);
|
||||
|
||||
/// Returns the map consisting of the most major `numResults` results.
|
||||
/// Returns the null AffineMap if `numResults` == 0.
|
||||
/// Returns `*this` if `numResults` >= `this->getNumResults()`.
|
||||
AffineMap getMajorSubMap(unsigned numResults);
|
||||
|
||||
/// Returns the map consisting of the most minor `numResults` results.
|
||||
/// Returns the null AffineMap if `numResults` == 0.
|
||||
/// Returns `*this` if `numResults` >= `this->getNumResults()`.
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
@ -602,6 +603,63 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
|
||||
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
|
||||
auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
|
||||
if (!transposeOp)
|
||||
return failure();
|
||||
|
||||
auto permutation = extractVector<unsigned>(transposeOp.transp());
|
||||
auto extractedPos = extractVector<int64_t>(extractOp.position());
|
||||
|
||||
// If transposition permutation is larger than the ExtractOp, all minor
|
||||
// dimensions must be an identity for folding to occur. If not, individual
|
||||
// elements within the extracted value are transposed and this is not just a
|
||||
// simple folding.
|
||||
unsigned minorRank = permutation.size() - extractedPos.size();
|
||||
MLIRContext *ctx = extractOp.getContext();
|
||||
AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
|
||||
AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
|
||||
if (minorMap && !AffineMap::isMinorIdentity(minorMap))
|
||||
return failure();
|
||||
|
||||
// %1 = transpose %0[x, y, z] : vector<axbxcxf32>
|
||||
// %2 = extract %1[u, v] : vector<..xf32>
|
||||
// may turn into:
|
||||
// %2 = extract %0[w, x] : vector<..xf32>
|
||||
// iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
|
||||
// -1 denotes the inverse.
|
||||
permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
|
||||
// The major submap has fewer results but the same number of dims. To compose
|
||||
// cleanly, we need to drop dims to form a "square matrix". This is possible
|
||||
// because:
|
||||
// (a) this is a permutation map and
|
||||
// (b) the minor map has already been checked to be identity.
|
||||
// Therefore, the major map cannot contain dims of position greater or equal
|
||||
// than the number of results.
|
||||
assert(llvm::all_of(permutationMap.getResults(),
|
||||
[&](AffineExpr e) {
|
||||
auto dim = e.dyn_cast<AffineDimExpr>();
|
||||
return dim && dim.getPosition() <
|
||||
permutationMap.getNumResults();
|
||||
}) &&
|
||||
"Unexpected map results depend on higher rank positions");
|
||||
// Project on the first domain dimensions to allow composition.
|
||||
permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
|
||||
permutationMap.getResults(), ctx);
|
||||
|
||||
extractOp.setOperand(transposeOp.vector());
|
||||
// Compose the inverse permutation map with the extractedPos.
|
||||
auto newExtractedPos =
|
||||
inversePermutation(permutationMap).compose(extractedPos);
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setAttr(ExtractOp::getPositionAttrName(),
|
||||
b.getI64ArrayAttr(newExtractedPos));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
|
||||
/// result is always the input to some InsertOp.
|
||||
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
|
||||
|
@ -689,6 +747,8 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
|
|||
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
|
||||
if (succeeded(foldExtractOpFromExtractChain(*this)))
|
||||
return getResult();
|
||||
if (succeeded(foldExtractOpFromTranspose(*this)))
|
||||
return getResult();
|
||||
if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
|
||||
return val;
|
||||
return OpFoldResult();
|
||||
|
|
|
@ -330,6 +330,21 @@ AffineMap AffineMap::compose(AffineMap map) {
|
|||
return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) {
|
||||
assert(getNumSymbols() == 0 && "Expected symbol-less map");
|
||||
SmallVector<AffineExpr, 4> exprs;
|
||||
exprs.reserve(values.size());
|
||||
MLIRContext *ctx = getContext();
|
||||
for (auto v : values)
|
||||
exprs.push_back(getAffineConstantExpr(v, ctx));
|
||||
auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
|
||||
SmallVector<int64_t, 4> res;
|
||||
res.reserve(resMap.getNumResults());
|
||||
for (auto e : resMap.getResults())
|
||||
res.push_back(e.cast<AffineConstantExpr>().getValue());
|
||||
return res;
|
||||
}
|
||||
|
||||
bool AffineMap::isProjectedPermutation() {
|
||||
if (getNumSymbols() > 0)
|
||||
return false;
|
||||
|
@ -360,6 +375,14 @@ AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
|
|||
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
|
||||
}
|
||||
|
||||
AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
|
||||
if (numResults == 0)
|
||||
return AffineMap();
|
||||
if (numResults > getNumResults())
|
||||
return *this;
|
||||
return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
|
||||
}
|
||||
|
||||
AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
|
||||
if (numResults == 0)
|
||||
return AffineMap();
|
||||
|
|
|
@ -300,13 +300,48 @@ func @insert_extract_transpose_3d_2d(
|
|||
|
||||
// CHECK-LABEL: fold_extracts
|
||||
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
|
||||
// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
|
||||
// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
|
||||
// CHECK-NEXT: return
|
||||
func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) {
|
||||
%b = vector.extract %a[0] : vector<3x4x5x6xf32>
|
||||
%c = vector.extract %b[1, 2] : vector<4x5x6xf32>
|
||||
// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
|
||||
%d = vector.extract %c[3] : vector<6xf32>
|
||||
|
||||
// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
|
||||
%e = vector.extract %a[0] : vector<3x4x5x6xf32>
|
||||
|
||||
// CHECK-NEXT: return
|
||||
return %d, %e : f32, vector<4x5x6xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_extract_transpose
|
||||
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
|
||||
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x6x5x6xf32>
|
||||
func @fold_extract_transpose(
|
||||
%a : vector<3x4x5x6xf32>, %b : vector<3x6x5x6xf32>) -> (
|
||||
vector<6xf32>, vector<6xf32>, vector<6xf32>) {
|
||||
// [3] is a proper most minor identity map in transpose.
|
||||
// Permutation is a self inverse and we have.
|
||||
// [0, 2, 1] ^ -1 o [0, 1, 2] = [0, 2, 1] o [0, 1, 2]
|
||||
// = [0, 2, 1]
|
||||
// CHECK-NEXT: vector.extract %[[A]][0, 2, 1] : vector<3x4x5x6xf32>
|
||||
%0 = vector.transpose %a, [0, 2, 1, 3] : vector<3x4x5x6xf32> to vector<3x5x4x6xf32>
|
||||
%1 = vector.extract %0[0, 1, 2] : vector<3x5x4x6xf32>
|
||||
|
||||
// [3] is a proper most minor identity map in transpose.
|
||||
// Permutation is a not self inverse and we have.
|
||||
// [1, 2, 0] ^ -1 o [0, 1, 2] = [2, 0, 1] o [0, 1, 2]
|
||||
// = [2, 0, 1]
|
||||
// CHECK-NEXT: vector.extract %[[A]][2, 0, 1] : vector<3x4x5x6xf32>
|
||||
%2 = vector.transpose %a, [1, 2, 0, 3] : vector<3x4x5x6xf32> to vector<4x5x3x6xf32>
|
||||
%3 = vector.extract %2[0, 1, 2] : vector<4x5x3x6xf32>
|
||||
|
||||
// Not a minor identity map so intra-vector level has been permuted
|
||||
// CHECK-NEXT: vector.transpose %[[B]], [0, 2, 3, 1]
|
||||
// CHECK-NEXT: vector.extract %{{.*}}[0, 1, 2]
|
||||
%4 = vector.transpose %b, [0, 2, 3, 1] : vector<3x6x5x6xf32> to vector<3x5x6x6xf32>
|
||||
%5 = vector.extract %4[0, 1, 2] : vector<3x5x6x6xf32>
|
||||
|
||||
return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue