[mlir][vector] NFC - Refactor and extract a helper StructuredGenerator class

Differential Revision: https://reviews.llvm.org/D111893
This commit is contained in:
Nicolas Vasilache 2021-10-15 15:56:58 +00:00
parent a59c1a2138
commit 60802715d1
2 changed files with 159 additions and 111 deletions

View File

@ -19,11 +19,14 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
namespace mlir {
class PatternRewriter;
/// Tests whether the given maps describe a row major matmul. The test is
/// permutation-invariant. Note that this only checks the affine maps from an
/// operation, so does not perform any checks on the math being performed within
@ -132,6 +135,60 @@ inline StringRef toString(IteratorType t) {
llvm_unreachable("Unsupported IteratorType");
}
/// Base helper for pattern code that matches and rewrites ops exposing
/// `iterator_types()` and `getIndexingMaps()`. It is a template over the op
/// type (named `StructuredOpInterface` here) because VectorOps do not yet
/// implement the StructuredOpInterface itself.
///
/// On construction it snapshots the op's context, location, iterator types
/// and indexing maps, which `iters` and `layout` then match against.
template <typename StructuredOpInterface>
class StructuredGenerator {
public:
  /// One list of affine expressions per indexing map.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;

  /// Wraps an iterator-type name so it can be compared against the string
  /// attributes stored in the op's `iterator_types` array.
  struct IteratorType {
    IteratorType(StringRef strRef) : strRef(strRef) {}

    /// Returns true iff `attr` is a `StringAttr` whose value equals the
    /// wrapped name.
    bool isOfType(Attribute attr) const {
      if (auto asString = attr.dyn_cast<StringAttr>())
        return asString.getValue() == strRef;
      return false;
    }

    StringRef strRef;
  };

  /// Shorthand for the name returned by `getParallelIteratorTypeName()`.
  struct Par : public IteratorType {
    Par() : IteratorType(getParallelIteratorTypeName()) {}
  };

  /// Shorthand for the name returned by `getReductionIteratorTypeName()`.
  struct Red : public IteratorType {
    Red() : IteratorType(getReductionIteratorTypeName()) {}
  };

  /// Shorthand for the name returned by `getWindowIteratorTypeName()`.
  struct Win : public IteratorType {
    Win() : IteratorType(getWindowIteratorTypeName()) {}
  };

  StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op)
      : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}

  /// Returns true iff the op has exactly `its.size()` iterators and the
  /// iterator at each position matches the corresponding entry of `its`.
  bool iters(ArrayRef<IteratorType> its) {
    if (iterators.size() != its.size())
      return false;
    for (unsigned pos = 0, end = its.size(); pos < end; ++pos)
      if (!its[pos].isOfType(iterators[pos]))
        return false;
    return true;
  }

  /// Returns true iff the op's indexing maps equal the maps inferred from the
  /// expression lists in `l`.
  bool layout(MapList l) {
    return maps == AffineMap::inferFromExprList(l);
  }

protected:
  PatternRewriter &rewriter;
  MLIRContext *ctx;
  Location loc;
  ArrayAttr iterators;
  SmallVector<AffineMap, 4> maps;
  Operation *op;
};
} // end namespace mlir
#endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H

View File

@ -1252,35 +1252,22 @@ struct Red : public IteratorType {
Red() : IteratorType(getReductionIteratorTypeName()) {}
};
// Unroll outer-products along reduction.
struct UnrolledOuterProductEmitter {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
: public StructuredGenerator<vector::ContractionOp> {
UnrolledOuterProductEmitter(PatternRewriter &rewriter,
vector::ContractionOp op)
: rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
/// Builds a generator for the contraction `op`. The base class is handed the
/// rewriter and the op; this class additionally caches the op's combining
/// kind, its lhs/rhs operands, its accumulator (stored as `res`) and the lhs
/// vector type.
UnrolledOuterProductGenerator(PatternRewriter &rewriter,
vector::ContractionOp op)
: StructuredGenerator<vector::ContractionOp>(rewriter, op),
kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
lhsType(op.getLhsType()) {}
/// Creates a `vector::TransposeOp` on `v` with permutation {1, 0} at the
/// generator's location and returns the new value.
Value t(Value v) {
  static constexpr std::array<int64_t, 2> kSwapDims = {1, 0};
  auto transposed = rewriter.create<vector::TransposeOp>(loc, v, kSwapDims);
  return transposed;
}
bool iters(ArrayRef<IteratorType> its) {
if (its.size() != iterators.size())
return false;
for (int i = 0, e = its.size(); i != e; ++i) {
if (!its[i].isOfType(iterators[i]))
return false;
}
return true;
}
bool layout(MapList l) {
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
return maps == infer(l);
}
LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
assert(reductionSize > 0);
for (int64_t k = 0; k < reductionSize; ++k) {
@ -1293,12 +1280,93 @@ struct UnrolledOuterProductEmitter {
return success();
}
PatternRewriter &rewriter;
Location loc;
/// Matmat flavor: two outer parallel iterators and one inner reduction.
/// Fails unless the iterators are (parallel, parallel, reduction) and the
/// indexing maps match one of the eight layouts below. On a match, any needed
/// transposes are created and the work is handed to `outer_prod`.
LogicalResult matmat() {
  const bool hasMatmatIterators = iters({Par(), Par(), Red()});
  if (!hasMatmatIterators)
    return failure();

  // Bind the parallel/reduction dims used to spell out the layouts.
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);

  // Classical row-major matmul: only the lhs is permuted.
  if (layout({{m, k}, {k, n}, {m, n}})) {
    Value lhsT = t(lhs);
    return outer_prod(lhsT, rhs, res, lhsType.getDimSize(1));
  }
  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
  if (layout({{m, k}, {n, k}, {m, n}})) {
    // The lhs transpose is created before the rhs transpose.
    Value lhsT = t(lhs);
    Value rhsT = t(rhs);
    return outer_prod(lhsT, rhsT, res, lhsType.getDimSize(1));
  }
  // Operands are used as they are; nothing is permuted.
  if (layout({{k, m}, {k, n}, {m, n}}))
    return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
  // Only the rhs is permuted.
  if (layout({{k, m}, {n, k}, {m, n}})) {
    Value rhsT = t(rhs);
    return outer_prod(lhs, rhsT, res, lhsType.getDimSize(0));
  }

  // Transposed output: the roles of RHS and LHS are swapped.
  // Classical row-major matmul: the lhs is permuted.
  if (layout({{m, k}, {k, n}, {n, m}})) {
    Value lhsT = t(lhs);
    return outer_prod(rhs, lhsT, res, lhsType.getDimSize(1));
  }
  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
  if (layout({{m, k}, {n, k}, {n, m}})) {
    // The rhs transpose is created before the lhs transpose.
    Value rhsT = t(rhs);
    Value lhsT = t(lhs);
    return outer_prod(rhsT, lhsT, res, lhsType.getDimSize(1));
  }
  if (layout({{k, m}, {k, n}, {n, m}}))
    return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
  if (layout({{k, m}, {n, k}, {n, m}})) {
    Value rhsT = t(rhs);
    return outer_prod(rhsT, lhs, res, lhsType.getDimSize(0));
  }

  // No supported layout matched.
  return failure();
}
/// Matvec flavor: one outer parallel iterator and one inner reduction.
/// Fails unless the iterators are (parallel, reduction) and the indexing maps
/// match one of the four layouts below.
LogicalResult matvec() {
  const bool hasMatvecIterators = iters({Par(), Red()});
  if (!hasMatvecIterators)
    return failure();

  AffineExpr m, k;
  bindDims(rewriter.getContext(), m, k);

  // Case mat-vec: the lhs is transposed first.
  if (layout({{m, k}, {k}, {m}})) {
    Value lhsT = t(lhs);
    return outer_prod(lhsT, rhs, res, lhsType.getDimSize(1));
  }
  // Case mat-trans-vec: operands are used as they are.
  if (layout({{k, m}, {k}, {m}}))
    return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
  // Case vec-mat: operands are swapped and the rhs is transposed.
  if (layout({{k}, {m, k}, {m}})) {
    Value rhsT = t(rhs);
    return outer_prod(rhsT, lhs, res, lhsType.getDimSize(0));
  }
  // Case vec-mat-trans: operands are swapped, no transpose.
  if (layout({{k}, {k, m}, {m}}))
    return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));

  // No supported layout matched.
  return failure();
}
/// Tmatvec flavor: one outer reduction iterator and one inner parallel.
/// Fails unless the iterators are (reduction, parallel) and the indexing maps
/// match one of the four layouts below. Note that `k` is bound to the first
/// dim and `m` to the second, the reverse of `matvec`.
LogicalResult tmatvec() {
  const bool hasTmatvecIterators = iters({Red(), Par()});
  if (!hasTmatvecIterators)
    return failure();

  AffineExpr k, m;
  bindDims(rewriter.getContext(), k, m);

  // Case mat-vec: the lhs is transposed first.
  if (layout({{m, k}, {k}, {m}})) {
    Value lhsT = t(lhs);
    return outer_prod(lhsT, rhs, res, lhsType.getDimSize(1));
  }
  // Case mat-trans-vec: operands are used as they are.
  if (layout({{k, m}, {k}, {m}}))
    return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
  // Case vec-mat: operands are swapped and the rhs is transposed.
  if (layout({{k}, {m, k}, {m}})) {
    Value rhsT = t(rhs);
    return outer_prod(rhsT, lhs, res, lhsType.getDimSize(0));
  }
  // Case vec-mat-trans: operands are swapped, no transpose.
  if (layout({{k}, {k, m}, {m}}))
    return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));

  // No supported layout matched.
  return failure();
}
private:
vector::CombiningKind kind;
ArrayAttr iterators;
SmallVector<AffineMap, 4> maps;
Operation *op;
Value lhs, rhs, res;
VectorType lhsType;
};
} // namespace
@ -1330,90 +1398,13 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
if (failed(filter(op)))
return failure();
VectorType lhsType = op.getLhsType();
Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
//
// Two outer parallel, one inner reduction (matmat flavor).
//
UnrolledOuterProductEmitter e(rewriter, op);
if (e.iters({Par(), Par(), Red()})) {
// Set up the parallel/reduction structure in right form.
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
// Classical row-major matmul: Just permute the lhs.
if (e.layout({{m, k}, {k, n}, {m, n}}))
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (e.layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = e.t(lhs);
return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1));
}
// No need to permute anything.
if (e.layout({{k, m}, {k, n}, {m, n}}))
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Just permute the rhs.
if (e.layout({{k, m}, {n, k}, {m, n}}))
return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (e.layout({{m, k}, {k, n}, {n, m}}))
return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (e.layout({{m, k}, {n, k}, {n, m}})) {
Value trhs = e.t(rhs);
return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1));
}
if (e.layout({{k, m}, {k, n}, {n, m}}))
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
if (e.layout({{k, m}, {n, k}, {n, m}}))
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
return failure();
}
//
// One outer parallel, one inner reduction (matvec flavor)
//
if (e.iters({Par(), Red()})) {
AffineExpr m, k;
bindDims(rewriter.getContext(), m, k);
// Case mat-vec: transpose.
if (e.layout({{m, k}, {k}, {m}}))
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
// Case mat-trans-vec: ready to go.
if (e.layout({{k, m}, {k}, {m}}))
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Case vec-mat: swap and transpose.
if (e.layout({{k}, {m, k}, {m}}))
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
// Case vec-mat-trans: swap and ready to go.
if (e.layout({{k}, {k, m}, {m}}))
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
return failure();
}
//
// One outer reduction, one inner parallel (tmatvec flavor)
//
if (e.iters({Red(), Par()})) {
AffineExpr k, m;
bindDims(rewriter.getContext(), k, m);
// Case mat-vec: transpose.
if (e.layout({{m, k}, {k}, {m}}))
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
// Case mat-trans-vec: ready to go.
if (e.layout({{k, m}, {k}, {m}}))
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Case vec-mat: swap and transpose.
if (e.layout({{k}, {m, k}, {m}}))
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
// Case vec-mat-trans: swap and ready to go.
if (e.layout({{k}, {k, m}, {m}}))
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
return failure();
}
UnrolledOuterProductGenerator e(rewriter, op);
if (succeeded(e.matmat()))
return success();
if (succeeded(e.matvec()))
return success();
if (succeeded(e.tmatvec()))
return success();
return failure();
}