[mlir][vector] Add new lowering mode to vector.contractionOp

Add lowering for cases where the reduction dimension is fully unrolled.
It is common to unroll the reduction dimension, therefore we would want
to lower the contractions to an elementwise vector op in this case.

Differential Revision: https://reviews.llvm.org/D126120
This commit is contained in:
Thomas Raoux 2022-05-24 14:16:00 +00:00
parent 6c80267d0f
commit 89aaa2d033
4 changed files with 246 additions and 45 deletions

View File

@ -49,6 +49,9 @@ enum class VectorContractLowering {
Matmul = 1,
/// Lower to `vector.outerproduct`.
OuterProduct = 2,
/// Lower contract with all reduction dimensions unrolled to 1 to a vector
/// elementwise operations.
ParallelArith = 3,
};
/// Enum to control the splitting of `vector.transfer` operations into
/// in-bounds and out-of-bounds variants.

View File

@ -144,6 +144,59 @@ static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
[](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
/// Helper to create arithmetic operation associated with a kind of contraction.
static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
Value acc,
vector::CombiningKind kind,
PatternRewriter &rewriter,
bool isInt) {
using vector::CombiningKind;
Value mul;
if (isInt) {
if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
// Only valid for floating point types.
return Optional<Value>();
mul = rewriter.create<arith::MulIOp>(loc, x, y);
} else {
// Float case.
if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
kind == CombiningKind::XOR)
// Only valid for integer types.
return Optional<Value>();
// Special case for fused multiply-add.
if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
}
mul = rewriter.create<arith::MulFOp>(loc, x, y);
}
if (!acc)
return Optional<Value>(mul);
return makeArithReduction(rewriter, loc, kind, mul, acc);
}
/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
ArrayAttr iteratorTypes) {
SmallVector<int64_t> dimsIdx;
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
dimsIdx.push_back(i);
}
return dimsIdx;
}
/// Look for a given dimension in an affine map and return its position. Return
/// llvm::None if the dimension is not in the map results.
static llvm::Optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
if (map.getDimPosition(i) == dim)
return i;
}
return llvm::None;
}
namespace {
/// ShapeCastOpFolder folds cancelling ShapeCastOps away.
@ -498,9 +551,8 @@ public:
if (!rhsType) {
// Special case: AXPY operation.
Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
Optional<Value> mult =
isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
: genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
kind, rewriter, isInt);
if (!mult.hasValue())
return failure();
rewriter.replaceOp(op, mult.getValue());
@ -518,8 +570,7 @@ public:
if (acc)
r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
Optional<Value> m =
isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
: genMultF(loc, a, op.getRhs(), r, kind, rewriter);
createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
if (!m.hasValue())
return failure();
result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
@ -528,48 +579,127 @@ public:
rewriter.replaceOp(op, result);
return success();
}
};
/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
struct ContractOpToElementwise
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
static LogicalResult defaultFilter(vector::ContractionOp op) {
return success();
}
ContractOpToElementwise(
vector::VectorTransformsOptions vectorTransformOptions,
MLIRContext *context,
const FilterConstraintType &constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
// TODO: implement masks
if (llvm::size(contractOp.getMasks()) != 0)
return failure();
if (failed(filter(contractOp)))
return failure();
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::ParallelArith)
return failure();
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
AffineMap lhsMap = contractOp.getIndexingMaps()[0];
AffineMap rhsMap = contractOp.getIndexingMaps()[1];
SmallVector<int64_t> lhsReductionDims =
getReductionIndex(lhsMap, contractOp.getIteratorTypes());
SmallVector<int64_t> rhsReductionDims =
getReductionIndex(rhsMap, contractOp.getIteratorTypes());
// All the reduction dimensions must be a size 1.
for (int64_t dim : lhsReductionDims) {
if (lhsShape[dim] != 1)
return failure();
}
for (int64_t dim : rhsReductionDims) {
if (rhsShape[dim] != 1)
return failure();
}
AffineMap accMap = contractOp.getIndexingMaps()[2];
unsigned numParallelDims = accMap.getNumResults();
unsigned numLhsDimToBroadcast =
numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
unsigned numRhsDimToBroadcast =
numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
SmallVector<int64_t> lhsDims;
SmallVector<int64_t> lhsTranspose;
SmallVector<int64_t> rhsDims;
SmallVector<int64_t> rhsTranspose;
for (int64_t dim : lhsReductionDims)
lhsTranspose.push_back(numLhsDimToBroadcast + dim);
for (int64_t dim : rhsReductionDims)
rhsTranspose.push_back(numRhsDimToBroadcast + dim);
// Loop through the parallel dimensions to calculate the dimensions to
// broadcast and to permute in order to extract only parallel dimensions.
for (unsigned i = 0; i < numParallelDims; i++) {
llvm::Optional<unsigned> lhsDim =
getDimPosition(lhsMap, accMap.getDimPosition(i));
if (lhsDim) {
lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
} else {
// If the parallel dimension doesn't exist we will have to broadcast it.
lhsDims.push_back(
contractOp.getResultType().cast<VectorType>().getDimSize(i));
lhsTranspose.push_back(lhsDims.size() - 1);
}
llvm::Optional<unsigned> rhsDim =
getDimPosition(rhsMap, accMap.getDimPosition(i));
if (rhsDim) {
rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
} else {
// If the parallel dimension doesn't exist we will have to broadcast it.
rhsDims.push_back(
contractOp.getResultType().cast<VectorType>().getDimSize(i));
rhsTranspose.push_back(rhsDims.size() - 1);
}
}
Value newLhs = contractOp.getLhs();
Value newRhs = contractOp.getRhs();
Location loc = contractOp.getLoc();
if (!lhsDims.empty()) {
lhsDims.append(lhsShape.begin(), lhsShape.end());
auto expandedType =
VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
}
if (!rhsDims.empty()) {
rhsDims.append(rhsShape.begin(), rhsShape.end());
auto expandedType =
VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
}
bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
newLhs = rewriter.create<vector::ExtractOp>(
loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
newRhs = rewriter.create<vector::ExtractOp>(
loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
Optional<Value> result =
createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
contractOp.getKind(), rewriter, isInt);
rewriter.replaceOp(contractOp, {*result});
return success();
}
private:
static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind,
PatternRewriter &rewriter) {
using vector::CombiningKind;
auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
if (!acc)
return Optional<Value>(mul);
if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
// Only valid for floating point types.
return Optional<Value>();
return makeArithReduction(rewriter, loc, kind, mul, acc);
}
static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind,
PatternRewriter &rewriter) {
using vector::CombiningKind;
// Special case for fused multiply-add.
if (acc && kind == CombiningKind::ADD) {
return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
}
auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
if (!acc)
return Optional<Value>(mul);
if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
kind == CombiningKind::OR || kind == CombiningKind::XOR)
// Already handled or only valid for integer types.
return Optional<Value>();
return makeArithReduction(rewriter, loc, kind, mul, acc);
}
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
};
/// Progressive lowering of ConstantMaskOp.
@ -1594,6 +1724,9 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
if (succeeded(pat3.matchAndRewrite(op, rewriter)))
return success();
ContractOpToElementwise pat4(vectorTransformOptions, ctx);
if (succeeded(pat4.matchAndRewrite(op, rewriter)))
return success();
// Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();

View File

@ -2,6 +2,7 @@
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
#dotp_accesses = [
affine_map<(i) -> (i)>,
@ -1104,3 +1105,54 @@ func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>,
: vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
return %0 : vector<3x4xf32>
}
// PARALLEL-LABEL: func @parrallel_contract_lowering
// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// PARALLEL: return %[[F]] : vector<4xf32>
func.func @parrallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
%0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
return %0 : vector<4xf32>
}
// PARALLEL-LABEL: func @parrallel_contract_lowering_broadcast
// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// PARALLEL: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// PARALLEL: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// PARALLEL: return %[[F]] : vector<4xf32>
func.func @parrallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
%0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
return %0 : vector<4xf32>
}
// PARALLEL-LABEL: func @parrallel_contract_lowering
// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// PARALLEL: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// PARALLEL: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
// PARALLEL: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32>
// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32>
// PARALLEL: return %[[F]] : vector<4xf32>
func.func @parrallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
%0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
return %0 : vector<4xf32>
}
// PARALLEL-LABEL: func @parrallel_contract_lowering_scalar
// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
// PARALLEL: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32
// PARALLEL: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
// PARALLEL: return %[[A]] : f32
func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>],
iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
%arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
return %0 : f32
}

View File

@ -135,6 +135,10 @@ struct TestVectorContractionLowering
llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
"vectors of size 4."),
llvm::cl::init(false)};
Option<bool> lowerToParallelArith{
*this, "vector-parallel-arith",
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
llvm::cl::init(false)};
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
@ -165,6 +169,15 @@ struct TestVectorContractionLowering
return;
}
if (lowerToParallelArith) {
vector::populateVectorContractLoweringPatterns(
patterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::ParallelArith));
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
// Test on all contract lowering patterns.
VectorContractLowering contractLowering = VectorContractLowering::Dot;
if (lowerToFlatMatrix)