forked from OSchip/llvm-project
[mlir][Vector] NFC - Compress vector to outerproduct lowering.
The implementation has become too unwieldy and cognitive overhead wins. Instead compress the implementation in preparation for additional lowering paths. This is a resubmit of https://reviews.llvm.org/D105359 without ordering ambiguities. Differential Revision: https://reviews.llvm.org/D105367
This commit is contained in:
parent
dd1c4bd09d
commit
cb5de7c813
|
@ -1816,6 +1816,72 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct IteratorType {
|
||||
IteratorType(StringRef strRef) : strRef(strRef) {}
|
||||
bool isOfType(Attribute attr) const {
|
||||
auto sAttr = attr.dyn_cast<StringAttr>();
|
||||
return sAttr && sAttr.getValue() == strRef;
|
||||
}
|
||||
StringRef strRef;
|
||||
};
|
||||
struct Par : public IteratorType {
|
||||
Par() : IteratorType(getParallelIteratorTypeName()) {}
|
||||
};
|
||||
struct Red : public IteratorType {
|
||||
Red() : IteratorType(getReductionIteratorTypeName()) {}
|
||||
};
|
||||
|
||||
// Unroll outer-products along reduction.
|
||||
struct UnrolledOuterProductEmitter {
|
||||
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
|
||||
|
||||
UnrolledOuterProductEmitter(PatternRewriter &rewriter,
|
||||
vector::ContractionOp op)
|
||||
: rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
|
||||
iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
|
||||
|
||||
Value t(Value v) {
|
||||
static constexpr std::array<int64_t, 2> perm = {1, 0};
|
||||
return rewriter.create<vector::TransposeOp>(loc, v, perm);
|
||||
}
|
||||
|
||||
bool iters(ArrayRef<IteratorType> its) {
|
||||
if (its.size() != iterators.size())
|
||||
return false;
|
||||
for (int i = 0, e = its.size(); i != e; ++i) {
|
||||
if (!its[i].isOfType(iterators[i]))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool layout(MapList l) {
|
||||
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
|
||||
return maps == infer(l);
|
||||
}
|
||||
|
||||
LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
|
||||
assert(reductionSize > 0);
|
||||
for (int64_t k = 0; k < reductionSize; ++k) {
|
||||
Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
|
||||
Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
|
||||
res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
|
||||
res, kind);
|
||||
}
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
}
|
||||
|
||||
PatternRewriter &rewriter;
|
||||
Location loc;
|
||||
vector::CombiningKind kind;
|
||||
ArrayAttr iterators;
|
||||
SmallVector<AffineMap, 4> maps;
|
||||
Operation *op;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
|
||||
/// semantics to a reduction_size-unrolled sequence:
|
||||
/// ```
|
||||
|
@ -1844,104 +1910,68 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
|
|||
if (failed(filter(op)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
int64_t reductionSize = 0;
|
||||
VectorType lhsType = op.getLhsType();
|
||||
Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
|
||||
|
||||
// Set up the parallel/reduction structure in right form.
|
||||
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
|
||||
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
|
||||
AffineExpr m, n, k;
|
||||
bindDims(rewriter.getContext(), m, n, k);
|
||||
static constexpr std::array<int64_t, 2> perm = {1, 0};
|
||||
auto iteratorTypes = op.iterator_types().getValue();
|
||||
SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
|
||||
if (isParallelIterator(iteratorTypes[0]) &&
|
||||
isParallelIterator(iteratorTypes[1]) &&
|
||||
isReductionIterator(iteratorTypes[2])) {
|
||||
|
||||
//
|
||||
// Two outer parallel, one inner reduction (matmat flavor).
|
||||
//
|
||||
if (maps == infer({{m, k}, {k, n}, {m, n}})) {
|
||||
// This is the classical row-major matmul. Just permute the lhs.
|
||||
reductionSize = lhsType.getDimSize(1);
|
||||
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
|
||||
} else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
|
||||
UnrolledOuterProductEmitter e(rewriter, op);
|
||||
if (e.iters({Par(), Par(), Red()})) {
|
||||
// Classical row-major matmul: Just permute the lhs.
|
||||
if (e.layout({{m, k}, {k, n}, {m, n}}))
|
||||
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
||||
reductionSize = lhsType.getDimSize(1);
|
||||
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
|
||||
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
|
||||
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
|
||||
if (e.layout({{m, k}, {n, k}, {m, n}})) {
|
||||
Value tlhs = e.t(lhs);
|
||||
return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1));
|
||||
}
|
||||
// No need to permute anything.
|
||||
reductionSize = lhsType.getDimSize(0);
|
||||
} else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
|
||||
if (e.layout({{k, m}, {k, n}, {m, n}}))
|
||||
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
||||
// Just permute the rhs.
|
||||
reductionSize = lhsType.getDimSize(0);
|
||||
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
|
||||
} else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
|
||||
// This is the classical row-major matmul. Just permute the lhs.
|
||||
reductionSize = lhsType.getDimSize(1);
|
||||
Value tmp = rhs;
|
||||
rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
|
||||
lhs = tmp;
|
||||
} else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
|
||||
if (e.layout({{k, m}, {n, k}, {m, n}}))
|
||||
return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
|
||||
// Transposed output: swap RHS and LHS.
|
||||
// Classical row-major matmul: permute the lhs.
|
||||
if (e.layout({{m, k}, {k, n}, {n, m}}))
|
||||
return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
|
||||
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
|
||||
reductionSize = lhsType.getDimSize(1);
|
||||
Value tmp = rhs;
|
||||
rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
|
||||
lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
|
||||
} else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
|
||||
// No need to permute anything, but still swap lhs and rhs.
|
||||
reductionSize = lhsType.getDimSize(0);
|
||||
std::swap(lhs, rhs);
|
||||
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
|
||||
// Just permute the rhs.
|
||||
reductionSize = lhsType.getDimSize(0);
|
||||
Value tmp = lhs;
|
||||
lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
|
||||
rhs = tmp;
|
||||
} else {
|
||||
if (e.layout({{m, k}, {n, k}, {n, m}})) {
|
||||
Value trhs = e.t(rhs);
|
||||
return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1));
|
||||
}
|
||||
if (e.layout({{k, m}, {k, n}, {n, m}}))
|
||||
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
||||
if (e.layout({{k, m}, {n, k}, {n, m}}))
|
||||
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||
return failure();
|
||||
}
|
||||
} else if (isParallelIterator(iteratorTypes[0]) &&
|
||||
isReductionIterator(iteratorTypes[1])) {
|
||||
|
||||
//
|
||||
// One outer parallel, one inner reduction (matvec flavor)
|
||||
//
|
||||
if (maps == infer({{m, n}, {n}, {m}})) {
|
||||
if (e.iters({Par(), Red()})) {
|
||||
// Case mat-vec: transpose.
|
||||
reductionSize = lhsType.getDimSize(1);
|
||||
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
|
||||
} else if (maps == infer({{n, m}, {n}, {m}})) {
|
||||
if (e.layout({{m, n}, {n}, {m}}))
|
||||
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
|
||||
// Case mat-trans-vec: ready to go.
|
||||
reductionSize = lhsType.getDimSize(0);
|
||||
} else if (maps == infer({{n}, {m, n}, {m}})) {
|
||||
if (e.layout({{n, m}, {n}, {m}}))
|
||||
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
|
||||
// Case vec-mat: swap and transpose.
|
||||
reductionSize = lhsType.getDimSize(0);
|
||||
std::swap(lhs, rhs);
|
||||
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
|
||||
} else if (maps == infer({{n}, {n, m}, {m}})) {
|
||||
if (e.layout({{n}, {m, n}, {m}}))
|
||||
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
|
||||
// Case vec-mat-trans: swap and ready to go.
|
||||
reductionSize = lhsType.getDimSize(0);
|
||||
std::swap(lhs, rhs);
|
||||
} else {
|
||||
if (e.layout({{n}, {n, m}, {m}}))
|
||||
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
|
||||
return failure();
|
||||
}
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
assert(reductionSize > 0);
|
||||
|
||||
// Unroll outer-products along reduction.
|
||||
for (int64_t k = 0; k < reductionSize; ++k) {
|
||||
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
|
||||
Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
|
||||
res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
|
||||
b, res, op.kind());
|
||||
}
|
||||
rewriter.replaceOp(op, res);
|
||||
return success();
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
|
|
Loading…
Reference in New Issue