Revert "[mlir][Vector] NFC - Compress vector to outerproduct lowering."

This reverts commit db188adfb1.

Breaks the GCC tests, likely because of some order of evaluation
difference between clang and gcc.
This commit is contained in:
Mehdi Amini 2021-07-02 17:55:06 +00:00
parent c7c5a1c9ae
commit 4525d52c73
1 changed files with 88 additions and 114 deletions

View File

@ -1816,72 +1816,6 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
return success();
}
namespace {
struct IteratorType {
IteratorType(StringRef strRef) : strRef(strRef) {}
bool isOfType(Attribute attr) const {
auto sAttr = attr.dyn_cast<StringAttr>();
return sAttr && sAttr.getValue() == strRef;
}
StringRef strRef;
};
struct Par : public IteratorType {
Par() : IteratorType(getParallelIteratorTypeName()) {}
};
struct Red : public IteratorType {
Red() : IteratorType(getReductionIteratorTypeName()) {}
};
// Unroll outer-products along reduction.
struct UnrolledOuterProductEmitter {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
UnrolledOuterProductEmitter(PatternRewriter &rewriter,
vector::ContractionOp op)
: rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
Value t(Value v) {
static constexpr std::array<int64_t, 2> perm = {1, 0};
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}
bool iters(ArrayRef<IteratorType> its) {
if (its.size() != iterators.size())
return false;
for (int i = 0, e = its.size(); i != e; ++i) {
if (!its[i].isOfType(iterators[i]))
return false;
}
return true;
}
bool layout(MapList l) {
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
return maps == infer(l);
}
LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
assert(reductionSize > 0);
for (int64_t k = 0; k < reductionSize; ++k) {
Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
res, kind);
}
rewriter.replaceOp(op, res);
return success();
}
PatternRewriter &rewriter;
Location loc;
vector::CombiningKind kind;
ArrayAttr iterators;
SmallVector<AffineMap, 4> maps;
Operation *op;
};
} // namespace
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
@ -1910,64 +1844,104 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
if (failed(filter(op)))
return failure();
Location loc = op.getLoc();
int64_t reductionSize = 0;
VectorType lhsType = op.getLhsType();
Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
// Set up the parallel/reduction structure in right form.
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
AffineExpr m, n, k;
bindDims(rewriter.getContext(), m, n, k);
//
// Two outer parallel, one inner reduction (matmat flavor).
//
UnrolledOuterProductEmitter e(rewriter, op);
if (e.iters({Par(), Par(), Red()})) {
// Classical row-major matmul: Just permute the lhs.
if (e.layout({{m, k}, {k, n}, {m, n}}))
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (e.layout({{m, k}, {n, k}, {m, n}}))
return e.outer_prod(e.t(lhs), e.t(rhs), res, lhsType.getDimSize(1));
// No need to permute anything.
if (e.layout({{k, m}, {k, n}, {m, n}}))
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Just permute the rhs.
if (e.layout({{k, m}, {n, k}, {m, n}}))
return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
// Transposed output: swap RHS and LHS.
// Classical row-major matmul: permute the lhs.
if (e.layout({{m, k}, {k, n}, {n, m}}))
return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (e.layout({{m, k}, {n, k}, {n, m}}))
return e.outer_prod(e.t(rhs), e.t(lhs), res, lhsType.getDimSize(1));
if (e.layout({{k, m}, {k, n}, {n, m}}))
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
if (e.layout({{k, m}, {n, k}, {n, m}}))
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
static constexpr std::array<int64_t, 2> perm = {1, 0};
auto iteratorTypes = op.iterator_types().getValue();
SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
if (isParallelIterator(iteratorTypes[0]) &&
isParallelIterator(iteratorTypes[1]) &&
isReductionIterator(iteratorTypes[2])) {
//
// Two outer parallel, one inner reduction (matmat flavor).
//
if (maps == infer({{m, k}, {k, n}, {m, n}})) {
// This is the classical row-major matmul. Just permute the lhs.
reductionSize = lhsType.getDimSize(1);
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
} else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
reductionSize = lhsType.getDimSize(1);
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
} else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
// No need to permute anything.
reductionSize = lhsType.getDimSize(0);
} else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
// Just permute the rhs.
reductionSize = lhsType.getDimSize(0);
rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
} else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
// This is the classical row-major matmul. Just permute the lhs.
reductionSize = lhsType.getDimSize(1);
Value tmp = rhs;
rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
lhs = tmp;
} else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
reductionSize = lhsType.getDimSize(1);
Value tmp = rhs;
rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
} else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
// No need to permute anything, but still swap lhs and rhs.
reductionSize = lhsType.getDimSize(0);
std::swap(lhs, rhs);
} else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
// Just permute the rhs.
reductionSize = lhsType.getDimSize(0);
Value tmp = lhs;
lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
rhs = tmp;
} else {
return failure();
}
} else if (isParallelIterator(iteratorTypes[0]) &&
isReductionIterator(iteratorTypes[1])) {
//
// One outer parallel, one inner reduction (matvec flavor)
//
if (maps == infer({{m, n}, {n}, {m}})) {
// Case mat-vec: transpose.
reductionSize = lhsType.getDimSize(1);
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
} else if (maps == infer({{n, m}, {n}, {m}})) {
// Case mat-trans-vec: ready to go.
reductionSize = lhsType.getDimSize(0);
} else if (maps == infer({{n}, {m, n}, {m}})) {
// Case vec-mat: swap and transpose.
reductionSize = lhsType.getDimSize(0);
std::swap(lhs, rhs);
lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
} else if (maps == infer({{n}, {n, m}, {m}})) {
// Case vec-mat-trans: swap and ready to go.
reductionSize = lhsType.getDimSize(0);
std::swap(lhs, rhs);
} else {
return failure();
}
} else {
return failure();
}
assert(reductionSize > 0);
//
// One outer parallel, one inner reduction (matvec flavor)
//
if (e.iters({Par(), Red()})) {
// Case mat-vec: transpose.
if (e.layout({{m, n}, {n}, {m}}))
return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
// Case mat-trans-vec: ready to go.
if (e.layout({{n, m}, {n}, {m}}))
return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
// Case vec-mat: swap and transpose.
if (e.layout({{n}, {m, n}, {m}}))
return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
// Case vec-mat-trans: swap and ready to go.
if (e.layout({{n}, {n, m}, {m}}))
return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
return failure();
// Unroll outer-products along reduction.
for (int64_t k = 0; k < reductionSize; ++k) {
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
b, res, op.kind());
}
return failure();
rewriter.replaceOp(op, res);
return success();
}
LogicalResult