[mlir][Vector] Add ExtractOp folding when fed by a TransposeOp

A TransposeOp is often followed by an ExtractOp.
In certain cases, however, it is unnecessary (and even detrimental) to lower a TransposeOp to either a flat transpose (llvm.matrix intrinsics) or to unrolled scalar insert / extract chains.

Providing foldings of ExtractOp mitigates some of the unnecessary complexity.

Differential revision: https://reviews.llvm.org/D83487
This commit is contained in:
Nicolas Vasilache 2020-07-10 09:49:22 -04:00
parent 9fd4b5faac
commit a490d387e6
4 changed files with 130 additions and 3 deletions

View File

@ -170,6 +170,10 @@ public:
/// `(d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)`
AffineMap compose(AffineMap map);
/// Applies composition by the dims of `this` to the integer `values` and
/// returns the resulting values. `this` must be symbol-less.
/// The `values` are wrapped as constant expressions, composed with `this`,
/// and each result is read back as an integer constant, in result order.
/// NOTE(review): presumably `values.size()` must match the number of dims of
/// `this` -- confirm against `compose(AffineMap)`.
SmallVector<int64_t, 4> compose(ArrayRef<int64_t> values);
/// Returns true if the AffineMap represents a subset (i.e. a projection) of a
/// symbol-less permutation map. A map with any symbol is never a projected
/// permutation.
bool isProjectedPermutation();
@ -180,6 +184,11 @@ public:
/// Returns the map consisting of the `resultPos` subset.
/// The returned map has the same number of dims and symbols as `this`.
AffineMap getSubMap(ArrayRef<unsigned> resultPos);
/// Returns the map consisting of the most major `numResults` results, i.e.
/// the leading results at positions [0, `numResults`).
/// Returns the null AffineMap if `numResults` == 0.
/// Returns `*this` if `numResults` >= `this->getNumResults()`.
AffineMap getMajorSubMap(unsigned numResults);
/// Returns the map consisting of the most minor `numResults` results.
/// Returns the null AffineMap if `numResults` == 0.
/// Returns `*this` if `numResults` >= `this->getNumResults()`.

View File

@ -18,6 +18,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@ -602,6 +603,63 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
return success();
}
/// Fold, in place, an ExtractOp whose source vector is produced by a
/// TransposeOp. On success the ExtractOp reads directly from the operand of
/// the TransposeOp at a permuted position; on failure the op is left untouched.
static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
  auto transposeOp = extractOp.vector().getDefiningOp<TransposeOp>();
  if (!transposeOp)
    return failure();

  MLIRContext *ctx = extractOp.getContext();
  auto transposePerm = extractVector<unsigned>(transposeOp.transp());
  auto oldPos = extractVector<int64_t>(extractOp.position());
  AffineMap fullMap = AffineMap::getPermutationMap(transposePerm, ctx);

  // Dimensions of the transpose that are not indexed by the ExtractOp remain
  // inside the extracted value. Unless the transpose leaves them in place
  // (minor identity), elements within the extracted value get shuffled and
  // the extract cannot simply be redirected to the transpose source.
  unsigned numMinorDims = transposePerm.size() - oldPos.size();
  AffineMap minorPart = fullMap.getMinorSubMap(numMinorDims);
  bool minorIsUntouched = !minorPart || AffineMap::isMinorIdentity(minorPart);
  if (!minorIsUntouched)
    return failure();

  // Example:
  //   %1 = transpose %0, [p0, p1, 2] : vector<axbxcxf32>
  //   %2 = extract %1[u, v] : vector<..xf32>
  // folds into:
  //   %2 = extract %0[s, t] : vector<..xf32>
  // where [s, t] = [p0, p1]^-1 o [u, v], with `o` denoting composition and
  // `^-1` the inverse permutation.
  AffineMap majorPart = fullMap.getMajorSubMap(oldPos.size());
  const unsigned numMajorResults = majorPart.getNumResults();

  // `majorPart` still has as many dims as the full permutation but fewer
  // results. Since the full map is a permutation and its minor part is an
  // identity, the major results can only reference the leading dims, so the
  // trailing dims can be dropped to obtain a square, invertible map.
  assert(llvm::all_of(majorPart.getResults(),
                      [&](AffineExpr e) {
                        auto dimExpr = e.dyn_cast<AffineDimExpr>();
                        if (!dimExpr)
                          return false;
                        return dimExpr.getPosition() < numMajorResults;
                      }) &&
         "Unexpected map results depend on higher rank positions");
  AffineMap squareMap =
      AffineMap::get(numMajorResults, 0, majorPart.getResults(), ctx);

  // Map the extraction position back through the inverse permutation.
  auto newPos = inversePermutation(squareMap).compose(oldPos);

  // Rewire the ExtractOp onto the transpose source at the new position. The
  // OpBuilder only serves as a helper to build the I64ArrayAttr.
  extractOp.setOperand(transposeOp.vector());
  OpBuilder attrBuilder(ctx);
  extractOp.setAttr(ExtractOp::getPositionAttrName(),
                    attrBuilder.getI64ArrayAttr(newPos));
  return success();
}
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
/// result is always the input to some InsertOp.
static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
@ -689,6 +747,8 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
if (succeeded(foldExtractOpFromTranspose(*this)))
return getResult();
if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
return val;
return OpFoldResult();

View File

@ -330,6 +330,21 @@ AffineMap AffineMap::compose(AffineMap map) {
return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
}
/// Evaluates this symbol-less map on constant integer `values`: the values
/// are wrapped into a map of constant expressions, composed with `this`, and
/// the constant results are returned as integers.
SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) {
  assert(getNumSymbols() == 0 && "Expected symbol-less map");
  MLIRContext *context = getContext();

  // Build a dim-less, symbol-less map whose results are the constants.
  SmallVector<AffineExpr, 4> constantExprs;
  constantExprs.reserve(values.size());
  for (int64_t value : values)
    constantExprs.push_back(getAffineConstantExpr(value, context));
  AffineMap constantMap = AffineMap::get(0, 0, constantExprs, context);

  // After composition every result is expected to be a constant.
  AffineMap composed = compose(constantMap);
  SmallVector<int64_t, 4> folded;
  folded.reserve(composed.getNumResults());
  for (AffineExpr resultExpr : composed.getResults()) {
    auto constant = resultExpr.cast<AffineConstantExpr>();
    folded.push_back(constant.getValue());
  }
  return folded;
}
bool AffineMap::isProjectedPermutation() {
if (getNumSymbols() > 0)
return false;
@ -360,6 +375,14 @@ AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
}
/// Returns the sub-map made of the leading `numResults` results of this map.
/// A request for zero results yields the null map; a request for more results
/// than available yields this map itself.
AffineMap AffineMap::getMajorSubMap(unsigned numResults) {
  if (numResults == 0)
    return AffineMap();
  if (numResults > getNumResults())
    return *this;

  // Keep result positions 0, 1, ..., numResults - 1.
  SmallVector<unsigned, 4> leadingPositions;
  for (unsigned pos = 0; pos < numResults; ++pos)
    leadingPositions.push_back(pos);
  return getSubMap(leadingPositions);
}
AffineMap AffineMap::getMinorSubMap(unsigned numResults) {
if (numResults == 0)
return AffineMap();

View File

@ -300,13 +300,48 @@ func @insert_extract_transpose_3d_2d(
// CHECK-LABEL: fold_extracts
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
// CHECK-NEXT: return
// A chain of extracts is expected to fold into a single extract on the
// original vector %a, with the positions of the chain concatenated
// ([0], then [1, 2], then [3]). The standalone extract %e stays as is.
func @fold_extracts(%a : vector<3x4x5x6xf32>) -> (f32, vector<4x5x6xf32>) {
%b = vector.extract %a[0] : vector<3x4x5x6xf32>
%c = vector.extract %b[1, 2] : vector<4x5x6xf32>
// CHECK-NEXT: vector.extract %[[A]][0, 1, 2, 3] : vector<3x4x5x6xf32>
%d = vector.extract %c[3] : vector<6xf32>
// CHECK-NEXT: vector.extract %[[A]][0] : vector<3x4x5x6xf32>
%e = vector.extract %a[0] : vector<3x4x5x6xf32>
// CHECK-NEXT: return
return %d, %e : f32, vector<4x5x6xf32>
}
// -----
// CHECK-LABEL: fold_extract_transpose
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4x5x6xf32>
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<3x6x5x6xf32>
// Exercises folding of an extract fed by a transpose: two foldable cases
// (minor part of the permutation is an identity) and one non-foldable case.
func @fold_extract_transpose(
%a : vector<3x4x5x6xf32>, %b : vector<3x6x5x6xf32>) -> (
vector<6xf32>, vector<6xf32>, vector<6xf32>) {
// The trailing [3] of the transpose permutation is a proper minor identity.
// The major permutation [0, 2, 1] is its own inverse, so we have:
// [0, 2, 1] ^ -1 o [0, 1, 2] = [0, 2, 1] o [0, 1, 2]
// = [0, 2, 1]
// CHECK-NEXT: vector.extract %[[A]][0, 2, 1] : vector<3x4x5x6xf32>
%0 = vector.transpose %a, [0, 2, 1, 3] : vector<3x4x5x6xf32> to vector<3x5x4x6xf32>
%1 = vector.extract %0[0, 1, 2] : vector<3x5x4x6xf32>
// The trailing [3] of the transpose permutation is a proper minor identity.
// The major permutation [1, 2, 0] is not its own inverse, so we have:
// [1, 2, 0] ^ -1 o [0, 1, 2] = [2, 0, 1] o [0, 1, 2]
// = [2, 0, 1]
// CHECK-NEXT: vector.extract %[[A]][2, 0, 1] : vector<3x4x5x6xf32>
%2 = vector.transpose %a, [1, 2, 0, 3] : vector<3x4x5x6xf32> to vector<4x5x3x6xf32>
%3 = vector.extract %2[0, 1, 2] : vector<4x5x3x6xf32>
// The minor part of the permutation (trailing [1]) is not an identity, so
// elements inside the extracted vector are permuted and nothing is folded.
// CHECK-NEXT: vector.transpose %[[B]], [0, 2, 3, 1]
// CHECK-NEXT: vector.extract %{{.*}}[0, 1, 2]
%4 = vector.transpose %b, [0, 2, 3, 1] : vector<3x6x5x6xf32> to vector<3x5x6x6xf32>
%5 = vector.extract %4[0, 1, 2] : vector<3x5x6x6xf32>
return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32>
}