forked from OSchip/llvm-project
lowerParallel is also called on unit-size, one-sided reduction dims
See: https://gist.github.com/bjacob/d8be8ec7e70ed0be4b3a5794ced2a7e8 Differential Revision: https://reviews.llvm.org/D129096
This commit is contained in:
parent
3968936b92
commit
6870a50f43
|
@ -527,11 +527,12 @@ private:
|
|||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
FilterConstraintType filter;
|
||||
// Lower one parallel dimension.
|
||||
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
|
||||
int64_t rhsIndex, PatternRewriter &rewriter) const;
|
||||
FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
|
||||
int64_t rhsIndex,
|
||||
PatternRewriter &rewriter) const;
|
||||
// Lower one reduction dimension.
|
||||
Value lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const;
|
||||
FailureOr<Value> lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const;
|
||||
};
|
||||
|
||||
} // namespace vector
|
||||
|
|
|
@ -1794,7 +1794,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
|
|||
if (!batchDimMap.empty()) {
|
||||
int64_t lhsIndex = batchDimMap[0].first;
|
||||
int64_t rhsIndex = batchDimMap[0].second;
|
||||
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
|
||||
auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
|
||||
if (failed(newOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, newOp.value());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1812,8 +1815,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
|
|||
VectorType lhsType = op.getLhsType();
|
||||
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
|
||||
if (lhsContractingDimSet.count(lhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
|
||||
auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
|
||||
if (failed(newOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, newOp.value());
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
@ -1822,15 +1827,20 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
|
|||
VectorType rhsType = op.getRhsType();
|
||||
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
|
||||
if (rhsContractingDimSet.count(rhsIndex) == 0) {
|
||||
rewriter.replaceOp(
|
||||
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
|
||||
auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
|
||||
if (failed(newOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, newOp.value());
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
// Lower the first remaining reduction dimension.
|
||||
if (!contractingDimMap.empty()) {
|
||||
rewriter.replaceOp(op, lowerReduction(op, rewriter));
|
||||
auto newOp = lowerReduction(op, rewriter);
|
||||
if (failed(newOp))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, newOp.value());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1838,10 +1848,12 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
|
|||
}
|
||||
|
||||
// Lower one parallel dimension.
|
||||
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
|
||||
// TODO: consider reusing existing contract unrolling
|
||||
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
|
||||
int64_t lhsIndex, int64_t rhsIndex,
|
||||
PatternRewriter &rewriter) const {
|
||||
FailureOr<Value>
|
||||
ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
|
||||
int64_t rhsIndex,
|
||||
PatternRewriter &rewriter) const {
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
VectorType resType = op.getResultType().cast<VectorType>();
|
||||
|
@ -1851,18 +1863,34 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
|
|||
int64_t dimSize = -1;
|
||||
if (lhsIndex >= 0) {
|
||||
iterIndex = iMap[0].getDimPosition(lhsIndex);
|
||||
assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
|
||||
"parallel index should be free in LHS or batch in LHS/RHS");
|
||||
if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
|
||||
<< " to map to the same dimension";
|
||||
});
|
||||
dimSize = lhsType.getDimSize(lhsIndex);
|
||||
} else {
|
||||
assert(rhsIndex >= 0 && "missing parallel index");
|
||||
} else if (rhsIndex >= 0) {
|
||||
iterIndex = iMap[1].getDimPosition(rhsIndex);
|
||||
dimSize = rhsType.getDimSize(rhsIndex);
|
||||
}
|
||||
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
|
||||
Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
|
||||
assert(lookup.has_value() && "parallel index not listed in reduction");
|
||||
int64_t resIndex = lookup.getValue();
|
||||
if (iterIndex < 0)
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "expected either lhsIndex=" << lhsIndex
|
||||
<< " or rhsIndex=" << rhsIndex << " to be nonnegative";
|
||||
});
|
||||
// getValueOr(-1) means that we tolerate a dimension not appearing
|
||||
// in the result map. That can't happen for actual parallel iterators, but
|
||||
// the caller ContractionOpLowering::matchAndRewrite is currently calling
|
||||
// lowerParallel also for the case of unit-size reduction dims appearing only
|
||||
// on one of LHS or RHS, not both. At the moment, such cases are created by
|
||||
// CastAwayContractionLeadingOneDim, so we need to either support that or
|
||||
// modify that pattern.
|
||||
int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1);
|
||||
if (resIndex == -1 && dimSize != 1)
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "expected the dimension for iterIndex=" << iterIndex
|
||||
<< " to either appear in the result map, or to be a unit dimension";
|
||||
});
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
std::array<AffineMap, 3> lowIndexingMaps = {
|
||||
adjustMap(iMap[0], iterIndex, rewriter),
|
||||
|
@ -1888,33 +1916,49 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
|
|||
}
|
||||
|
||||
// Lower one reduction dimension.
|
||||
Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
FailureOr<Value>
|
||||
ContractionOpLowering::lowerReduction(vector::ContractionOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
VectorType lhsType = op.getLhsType();
|
||||
VectorType rhsType = op.getRhsType();
|
||||
Type resType = op.getResultType();
|
||||
assert(!resType.isa<VectorType>());
|
||||
if (resType.isa<VectorType>())
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"did not expect a VectorType result");
|
||||
bool isInt = resType.isa<IntegerType>();
|
||||
// Use iterator index 0.
|
||||
int64_t iterIndex = 0;
|
||||
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
|
||||
Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
|
||||
Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
|
||||
assert(lookupLhs.has_value() && "missing LHS parallel index");
|
||||
assert(lookupRhs.has_value() && "missing RHS parallel index");
|
||||
if (!lookupLhs.hasValue())
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
|
||||
});
|
||||
if (!lookupRhs.hasValue())
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
|
||||
});
|
||||
int64_t lhsIndex = lookupLhs.getValue();
|
||||
int64_t rhsIndex = lookupRhs.getValue();
|
||||
int64_t dimSize = lhsType.getDimSize(lhsIndex);
|
||||
assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
|
||||
if (dimSize != rhsType.getDimSize(rhsIndex))
|
||||
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "expect LHS dimension " << lhsIndex
|
||||
<< " to have the same size as RHS dimension " << rhsIndex;
|
||||
});
|
||||
// Base case.
|
||||
if (lhsType.getRank() == 1) {
|
||||
assert(rhsType.getRank() == 1 && "corrupt contraction");
|
||||
if (rhsType.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "When LHS has rank 1, expected also RHS to have rank 1");
|
||||
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
|
||||
auto kind = vector::CombiningKind::ADD;
|
||||
if (auto acc = op.getAcc())
|
||||
return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
|
||||
return rewriter.create<vector::ReductionOp>(loc, kind, m);
|
||||
return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
|
||||
.getResult();
|
||||
return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
|
||||
}
|
||||
// Construct new iterator types and affine map array attribute.
|
||||
std::array<AffineMap, 3> lowIndexingMaps = {
|
||||
|
|
|
@ -858,6 +858,34 @@ func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x
|
|||
return %0 : vector<2x1x7xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
|
||||
// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
|
||||
// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
|
||||
// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32>
|
||||
// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32>
|
||||
// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
|
||||
// CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
|
||||
// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
|
||||
// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32>
|
||||
// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
|
||||
// CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
|
||||
// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
|
||||
// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
|
||||
// CHECK: return %[[S]] : vector<2xi32>
|
||||
|
||||
func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
|
||||
%res = vector.contract {
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d1)>
|
||||
],
|
||||
iterator_types = ["reduction", "parallel", "reduction"],
|
||||
kind = #vector.kind<add>
|
||||
} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
|
||||
return %res : vector<2xi32>
|
||||
}
|
||||
|
||||
#matmat_accesses_0 = [
|
||||
affine_map<(m, n, k) -> (m, k)>,
|
||||
affine_map<(m, n, k) -> (k, n)>,
|
||||
|
|
Loading…
Reference in New Issue