forked from OSchip/llvm-project
[mlir][vector] NFC - Refactor and extract a helper StructuredGenerator class
Differential Revision: https://reviews.llvm.org/D111893
This commit is contained in:
parent
a59c1a2138
commit
60802715d1
|
@ -19,11 +19,14 @@
|
|||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class PatternRewriter;
|
||||
|
||||
/// Tests whether the given maps describe a row major matmul. The test is
|
||||
/// permutation-invariant. Note that this only checks the affine maps from an
|
||||
/// operation, so does not perform any checks on the math being performed within
|
||||
|
@ -132,6 +135,60 @@ inline StringRef toString(IteratorType t) {
|
|||
llvm_unreachable("Unsupported IteratorType");
|
||||
}
|
||||
|
||||
/// Helper StructuredGenerator class to manipulate and rewrite ops with
|
||||
/// `StructuredOpInterface`. This is templated for now because VectorOps do not
|
||||
/// yet implement the StructuredOpInterface itself.
|
||||
template <typename StructuredOpInterface>
|
||||
class StructuredGenerator {
|
||||
public:
|
||||
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
|
||||
|
||||
struct IteratorType {
|
||||
IteratorType(StringRef strRef) : strRef(strRef) {}
|
||||
bool isOfType(Attribute attr) const {
|
||||
auto sAttr = attr.dyn_cast<StringAttr>();
|
||||
return sAttr && sAttr.getValue() == strRef;
|
||||
}
|
||||
StringRef strRef;
|
||||
};
|
||||
struct Par : public IteratorType {
|
||||
Par() : IteratorType(getParallelIteratorTypeName()) {}
|
||||
};
|
||||
struct Red : public IteratorType {
|
||||
Red() : IteratorType(getReductionIteratorTypeName()) {}
|
||||
};
|
||||
struct Win : public IteratorType {
|
||||
Win() : IteratorType(getWindowIteratorTypeName()) {}
|
||||
};
|
||||
|
||||
StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op)
|
||||
: rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
|
||||
iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
|
||||
|
||||
bool iters(ArrayRef<IteratorType> its) {
|
||||
if (its.size() != iterators.size())
|
||||
return false;
|
||||
for (int i = 0, e = its.size(); i != e; ++i) {
|
||||
if (!its[i].isOfType(iterators[i]))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool layout(MapList l) {
|
||||
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
|
||||
return maps == infer(l);
|
||||
}
|
||||
|
||||
protected:
|
||||
PatternRewriter &rewriter;
|
||||
MLIRContext *ctx;
|
||||
Location loc;
|
||||
ArrayAttr iterators;
|
||||
SmallVector<AffineMap, 4> maps;
|
||||
Operation *op;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H
|
||||
|
|
|
@ -1252,35 +1252,22 @@ struct Red : public IteratorType {
|
|||
Red() : IteratorType(getReductionIteratorTypeName()) {}
|
||||
};
|
||||
|
||||
// Unroll outer-products along reduction.
|
||||
struct UnrolledOuterProductEmitter {
|
||||
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
|
||||
/// Generate a vector implementation for matmat, matvec and tmatvec.
|
||||
/// This unrolls outer-products along the reduction dimension.
|
||||
struct UnrolledOuterProductGenerator
|
||||
: public StructuredGenerator<vector::ContractionOp> {
|
||||
|
||||
/// Records the rewriter plus the location, combining kind, iterator types and
/// indexing maps read from the contraction `op`.
UnrolledOuterProductEmitter(PatternRewriter &rewriter,
                            vector::ContractionOp op)
    : rewriter(rewriter),
      loc(op.getLoc()),
      kind(op.kind()),
      iterators(op.iterator_types()),
      maps(op.getIndexingMaps()),
      op(op) {}
|
||||
/// Forwards `rewriter` and `op` to the StructuredGenerator base, then caches
/// the contraction's combining kind, its lhs/rhs/acc values and the lhs
/// vector type.
UnrolledOuterProductGenerator(PatternRewriter &rewriter,
                              vector::ContractionOp op)
    : StructuredGenerator<vector::ContractionOp>(rewriter, op),
      kind(op.kind()),
      lhs(op.lhs()),
      rhs(op.rhs()),
      res(op.acc()), // The accumulator is what the methods refer to as `res`.
      lhsType(op.getLhsType()) {}
|
||||
|
||||
/// Creates a `vector::TransposeOp` on `v` at `loc` with permutation [1, 0]
/// and returns its result.
Value t(Value v) {
  // Permutation that exchanges dimension 0 and dimension 1.
  static constexpr std::array<int64_t, 2> kSwapDims = {1, 0};
  return rewriter.create<vector::TransposeOp>(loc, v, kSwapDims);
}
|
||||
|
||||
/// Returns true when the op has exactly as many iterators as `its` and each
/// iterator attribute matches the expected kind at the same position.
bool iters(ArrayRef<IteratorType> its) {
  const size_t numIterators = iterators.size();
  if (its.size() != numIterators)
    return false;
  size_t pos = 0;
  for (const IteratorType &expected : its) {
    if (!expected.isOfType(iterators[pos]))
      return false;
    ++pos;
  }
  return true;
}
|
||||
|
||||
/// Returns true when the op's indexing maps equal the maps inferred from the
/// expression lists in `l`.
bool layout(MapList l) {
  auto expectedMaps = AffineMap::inferFromExprList(l);
  return maps == expectedMaps;
}
|
||||
|
||||
LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
|
||||
assert(reductionSize > 0);
|
||||
for (int64_t k = 0; k < reductionSize; ++k) {
|
||||
|
@ -1293,12 +1280,93 @@ struct UnrolledOuterProductEmitter {
|
|||
return success();
|
||||
}
|
||||
|
||||
PatternRewriter &rewriter;
|
||||
Location loc;
|
||||
/// Two outer parallel, one inner reduction (matmat flavor).
|
||||
LogicalResult matmat() {
  // Only a (parallel, parallel, reduction) iterator list is handled here.
  const bool hasMatmatIterators = iters({Par(), Par(), Red()});
  if (!hasMatmatIterators)
    return failure();

  // Bind the three dimensions used to spell out the candidate layouts.
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);

  //
  // Result indexed as (m, n).
  //

  // Classical row-major matmul: only the lhs needs a transpose.
  if (layout({{m, k}, {k, n}, {m, n}})) {
    Value lhsT = t(lhs);
    return outer_prod(lhsT, rhs, res, lhsType.getDimSize(1));
  }
  // Both operands carry k last: transpose lhs first, then rhs.
  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
  if (layout({{m, k}, {n, k}, {m, n}})) {
    Value lhsT = t(lhs);
    Value rhsT = t(rhs);
    return outer_prod(lhsT, rhsT, res, lhsType.getDimSize(1));
  }
  // Both operands already carry k first: nothing to permute.
  if (layout({{k, m}, {k, n}, {m, n}})) {
    return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
  }
  // Only the rhs needs a transpose.
  if (layout({{k, m}, {n, k}, {m, n}})) {
    Value rhsT = t(rhs);
    return outer_prod(lhs, rhsT, res, lhsType.getDimSize(0));
  }

  //
  // Transposed result, indexed as (n, m): the roles of lhs and rhs swap.
  //

  // Classical row-major matmul: only the lhs needs a transpose.
  if (layout({{m, k}, {k, n}, {n, m}})) {
    Value lhsT = t(lhs);
    return outer_prod(rhs, lhsT, res, lhsType.getDimSize(1));
  }
  // Both operands carry k last: transpose rhs first, then lhs.
  // TODO: may be better to fail and use some vector<k> -> scalar reduction.
  if (layout({{m, k}, {n, k}, {n, m}})) {
    Value rhsT = t(rhs);
    Value lhsT = t(lhs);
    return outer_prod(rhsT, lhsT, res, lhsType.getDimSize(1));
  }
  // Both operands already carry k first: swap them without transposing.
  if (layout({{k, m}, {k, n}, {n, m}})) {
    return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
  }
  // Only the rhs needs a transpose.
  if (layout({{k, m}, {n, k}, {n, m}})) {
    Value rhsT = t(rhs);
    return outer_prod(rhsT, lhs, res, lhsType.getDimSize(0));
  }

  // None of the recognized layouts matched.
  return failure();
}
|
||||
|
||||
/// One outer parallel, one inner reduction (matvec flavor).
|
||||
LogicalResult matvec() {
  // Only a (parallel, reduction) iterator list is handled here.
  const bool hasMatvecIterators = iters({Par(), Red()});
  if (!hasMatvecIterators)
    return failure();

  // Bind the two dimensions used to spell out the candidate layouts.
  AffineExpr m, k;
  bindDims(rewriter.getContext(), m, k);

  // Case mat-vec: the matrix lhs is transposed before unrolling.
  if (layout({{m, k}, {k}, {m}})) {
    Value lhsT = t(lhs);
    return outer_prod(lhsT, rhs, res, lhsType.getDimSize(1));
  }
  // Case mat-trans-vec: operands are used as they are.
  if (layout({{k, m}, {k}, {m}})) {
    return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
  }
  // Case vec-mat: operands are swapped and the matrix rhs is transposed.
  if (layout({{k}, {m, k}, {m}})) {
    Value rhsT = t(rhs);
    return outer_prod(rhsT, lhs, res, lhsType.getDimSize(0));
  }
  // Case vec-mat-trans: operands are swapped, no transpose needed.
  if (layout({{k}, {k, m}, {m}})) {
    return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
  }

  // None of the recognized layouts matched.
  return failure();
}
|
||||
|
||||
/// One outer reduction, one inner parallel (tmatvec flavor).
|
||||
LogicalResult tmatvec() {
  // Only a (reduction, parallel) iterator list is handled here.
  const bool hasTmatvecIterators = iters({Red(), Par()});
  if (!hasTmatvecIterators)
    return failure();

  // Bind the dimensions reduction-first, mirroring the iterator order.
  AffineExpr k, m;
  bindDims(rewriter.getContext(), k, m);

  // Case mat-vec: the matrix lhs is transposed before unrolling.
  if (layout({{m, k}, {k}, {m}})) {
    Value lhsT = t(lhs);
    return outer_prod(lhsT, rhs, res, lhsType.getDimSize(1));
  }
  // Case mat-trans-vec: operands are used as they are.
  if (layout({{k, m}, {k}, {m}})) {
    return outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
  }
  // Case vec-mat: operands are swapped and the matrix rhs is transposed.
  if (layout({{k}, {m, k}, {m}})) {
    Value rhsT = t(rhs);
    return outer_prod(rhsT, lhs, res, lhsType.getDimSize(0));
  }
  // Case vec-mat-trans: operands are swapped, no transpose needed.
  if (layout({{k}, {k, m}, {m}})) {
    return outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
  }

  // None of the recognized layouts matched.
  return failure();
}
|
||||
|
||||
private:
|
||||
vector::CombiningKind kind;
|
||||
ArrayAttr iterators;
|
||||
SmallVector<AffineMap, 4> maps;
|
||||
Operation *op;
|
||||
Value lhs, rhs, res;
|
||||
VectorType lhsType;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -1330,90 +1398,13 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
|
|||
if (failed(filter(op)))
|
||||
return failure();
|
||||
|
||||
VectorType lhsType = op.getLhsType();
|
||||
Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
|
||||
|
||||
//
|
||||
// Two outer parallel, one inner reduction (matmat flavor).
|
||||
//
|
||||
UnrolledOuterProductEmitter e(rewriter, op);
|
||||
if (e.iters({Par(), Par(), Red()})) {
|
||||
// Set up the parallel/reduction structure in right form.
|
||||
AffineExpr m, n, k;
|
||||
bindDims(rewriter.getContext(), m, n, k);
|
||||
// Classical row-major matmul: Just permute the lhs.
|
||||
if (e.layout({{m, k}, {k, n}, {m, n}}))
|
||||
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
||||
if (e.layout({{m, k}, {n, k}, {m, n}})) {
|
||||
Value tlhs = e.t(lhs);
|
||||
return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1));
|
||||
}
|
||||
// No need to permute anything.
|
||||
if (e.layout({{k, m}, {k, n}, {m, n}}))
|
||||
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
||||
// Just permute the rhs.
|
||||
if (e.layout({{k, m}, {n, k}, {m, n}}))
|
||||
return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
|
||||
// Transposed output: swap RHS and LHS.
|
||||
// Classical row-major matmul: permute the lhs.
|
||||
if (e.layout({{m, k}, {k, n}, {n, m}}))
|
||||
return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
|
||||
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
||||
if (e.layout({{m, k}, {n, k}, {n, m}})) {
|
||||
Value trhs = e.t(rhs);
|
||||
return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1));
|
||||
}
|
||||
if (e.layout({{k, m}, {k, n}, {n, m}}))
|
||||
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
||||
if (e.layout({{k, m}, {n, k}, {n, m}}))
|
||||
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||
return failure();
|
||||
}
|
||||
|
||||
//
|
||||
// One outer parallel, one inner reduction (matvec flavor)
|
||||
//
|
||||
if (e.iters({Par(), Red()})) {
|
||||
AffineExpr m, k;
|
||||
bindDims(rewriter.getContext(), m, k);
|
||||
|
||||
// Case mat-vec: transpose.
|
||||
if (e.layout({{m, k}, {k}, {m}}))
|
||||
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||
// Case mat-trans-vec: ready to go.
|
||||
if (e.layout({{k, m}, {k}, {m}}))
|
||||
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
||||
// Case vec-mat: swap and transpose.
|
||||
if (e.layout({{k}, {m, k}, {m}}))
|
||||
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||
// Case vec-mat-trans: swap and ready to go.
|
||||
if (e.layout({{k}, {k, m}, {m}}))
|
||||
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
||||
return failure();
|
||||
}
|
||||
|
||||
//
|
||||
// One outer reduction, one inner parallel (tmatvec flavor)
|
||||
//
|
||||
if (e.iters({Red(), Par()})) {
|
||||
AffineExpr k, m;
|
||||
bindDims(rewriter.getContext(), k, m);
|
||||
|
||||
// Case mat-vec: transpose.
|
||||
if (e.layout({{m, k}, {k}, {m}}))
|
||||
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||
// Case mat-trans-vec: ready to go.
|
||||
if (e.layout({{k, m}, {k}, {m}}))
|
||||
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
||||
// Case vec-mat: swap and transpose.
|
||||
if (e.layout({{k}, {m, k}, {m}}))
|
||||
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||
// Case vec-mat-trans: swap and ready to go.
|
||||
if (e.layout({{k}, {k, m}, {m}}))
|
||||
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
||||
return failure();
|
||||
}
|
||||
UnrolledOuterProductGenerator e(rewriter, op);
|
||||
if (succeeded(e.matmat()))
|
||||
return success();
|
||||
if (succeeded(e.matvec()))
|
||||
return success();
|
||||
if (succeeded(e.tmatvec()))
|
||||
return success();
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue