[mlir][vector] Add patterns to convert multidimreduce to vector.contract

Add several patterns that will simplify contraction vectorization in the
future. With those canonicalizations we will be able to remove the special
case for contraction during vectorization and rely on those transformations to
avoid materializing broadcast ops.
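
As an illustrative sketch (not one of the tests added here), once all three
patterns have run, IR of the form

  %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
  %1 = arith.mulf %0, %arg1 : vector<8x32x16xf32>
  %2 = vector.multi_reduction #vector.kind<add>, %1 [1]
      : vector<8x32x16xf32> to vector<8x16xf32>

should collapse into a single contraction with no materialized broadcast:

  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
  %2 = vector.contract {indexing_maps = [
         affine_map<(d0, d1, d2) -> (d1, d2)>,
         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
         affine_map<(d0, d1, d2) -> (d0, d2)>],
        iterator_types = ["parallel", "reduction", "parallel"],
        kind = #vector.kind<add>} %arg0, %arg1, %cst
       : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>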

Differential Revision: https://reviews.llvm.org/D112121
thomasraoux 2021-10-21 13:20:52 -07:00
parent 5dc339d982
commit 1d8cc45b0e
4 changed files with 310 additions and 0 deletions

@@ -215,6 +215,10 @@ void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransformsOptions options = VectorTransformsOptions());
/// Collect patterns to convert reduction op to vector.contract and fold
/// transpose/broadcast ops into the contract.
void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

@@ -240,6 +240,13 @@ sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
return slicedIndices;
}
template <typename IntType>
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(llvm::map_range(
arrayAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}
namespace {
struct UnrollTransferReadPattern
@@ -1114,6 +1121,193 @@ private:
}
};
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
/// Ex:
/// ```
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
/// %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
/// : vector<8x32x16xf32> to vector<8x16xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d2)>],
/// iterator_types = ["parallel", "reduction", "parallel"],
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
/// ```
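/// As a further illustration (hypothetical, not one of the added tests), a
/// mask reducing two dimensions at once, e.g.
/// ```
/// %1 = vector.multi_reduction #vector.kind<add>, %0 [0, 2]
///    : vector<8x32x16xf32> to vector<32xf32>
/// ```
/// would yield iterator_types ["reduction", "parallel", "reduction"] and an
/// accumulator map of affine_map<(d0, d1, d2) -> (d1)>.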
struct MultiReduceToContract
: public OpRewritePattern<vector::MultiDimReductionOp> {
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
PatternRewriter &rewriter) const override {
if (reduceOp.kind() != vector::CombiningKind::ADD)
return failure();
Operation *mulOp = reduceOp.source().getDefiningOp();
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
return failure();
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
SmallVector<AffineExpr> exprs;
SmallVector<StringRef> iteratorTypes;
for (auto isReduceDim : llvm::enumerate(reductionMask)) {
if (!isReduceDim.value()) {
iteratorTypes.push_back(getParallelIteratorTypeName());
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
} else {
iteratorTypes.push_back(getReductionIteratorTypeName());
}
}
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
/*symCount=*/0, exprs, reduceOp.getContext());
Value zero = rewriter.create<arith::ConstantOp>(
reduceOp.getLoc(), reduceOp.getDestType(),
rewriter.getZeroAttr(reduceOp.getDestType()));
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
rewriter.getStrArrayAttr(iteratorTypes));
return success();
}
};
/// Merge TransposeOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.transpose %arg0, [2, 0, 1]
/// : vector<32x16x8xf32> to vector<8x32x16xf32>
/// %1 = vector.contract {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
/// : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
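/// The new map is obtained by composing the inverse of the transpose
/// permutation with the operand's original indexing map; for the example
/// above (a derivation sketch, not an additional test):
/// ```
/// permutation [2, 0, 1]           --> affine_map<(d0, d1, d2) -> (d2, d0, d1)>
/// inversePermutation(...)         --> affine_map<(d0, d1, d2) -> (d1, d2, d0)>
/// composed with the identity map  --> affine_map<(d0, d1, d2) -> (d1, d2, d0)>
/// ```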
struct CombineContractTranspose
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
SmallVector<AffineMap, 4> maps =
llvm::to_vector<4>(contractOp.getIndexingMaps());
Value lhs = contractOp.lhs();
Value rhs = contractOp.rhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
AffineMap &map = maps[index++];
auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
if (!transposeOp)
continue;
SmallVector<int64_t> perm;
transposeOp.getTransp(perm);
AffineMap permutationMap = AffineMap::getPermutationMap(
extractVector<unsigned>(transposeOp.transp()),
contractOp.getContext());
map = inversePermutation(permutationMap).compose(map);
*operand = transposeOp.vector();
changed = true;
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.acc(),
rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
return success();
}
};
/// Merge BroadcastOp into ContractionOp user.
/// Ex:
/// ```
/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
/// %1 = vector.contract {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
/// Gets converted to:
/// ```
/// %1 = vector.contract {indexing_maps = [
/// affine_map<(d0, d1, d2) -> (d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
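/// Only broadcasts that prepend leading dimensions are folded; a broadcast
/// that stretches an inner dimension, e.g. (hypothetical, not an added test)
/// ```
/// %0 = vector.broadcast %arg0 : vector<1x16xf32> to vector<8x32x16xf32>
/// ```
/// is left in place, since vector.contract cannot express it.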
struct CombineContractBroadcast
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
SmallVector<AffineMap, 4> maps =
llvm::to_vector<4>(contractOp.getIndexingMaps());
Value lhs = contractOp.lhs();
Value rhs = contractOp.rhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
AffineMap &map = maps[index++];
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
if (!broadcast)
continue;
// ContractionOp can only take vectors as operands.
auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
continue;
int64_t rankDiff =
broadcast.getVectorType().getRank() - srcType.getRank();
bool innerDimBroadcast = false;
SmallVector<AffineExpr> originalDims;
for (auto dim : llvm::enumerate(srcType.getShape())) {
if (dim.value() !=
broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
innerDimBroadcast = true;
break;
}
originalDims.push_back(
rewriter.getAffineDimExpr(dim.index() + rankDiff));
}
// Contract doesn't support inner dimension broadcast. Once this is
// relaxed we can remove this case.
if (innerDimBroadcast)
continue;
AffineMap broadcastMap =
AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
contractOp.getContext());
map = broadcastMap.compose(map);
*operand = broadcast.source();
changed = true;
}
if (!changed)
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.acc(),
rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
return success();
}
};
} // namespace
/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -3668,6 +3862,12 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
patterns.add<TransposeOpLowering>(options, patterns.getContext());
}
void mlir::vector::populateVetorReductionToContractPatterns(
RewritePatternSet &patterns) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
CombineContractTranspose>(patterns.getContext());
}
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns) {
patterns.add<TransferReadPermutationLowering,

@@ -0,0 +1,87 @@
// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: multidimreduction_contract
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
func @multidimreduction_contract(
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
return %1 : vector<8x16xf32>
}
// -----
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: multidimreduction_contract_int
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
func @multidimreduction_contract_int(
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
return %1 : vector<8x16xi32>
}
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: contract_transpose
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
func @contract_transpose(
%arg0: vector<32x16x8xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
%0 = vector.transpose %arg0, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32>
%1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
return %1 : vector<8x32xf32>
}
// -----
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: contract_broadcast
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
func @contract_broadcast(
%arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
%0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
%1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
return %1 : vector<8x32xf32>
}
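
// -----

// An additional sketch (not part of this commit): reductions with a combining
// kind other than <add> are left untouched by MultiReduceToContract.
// CHECK-LABEL: multidimreduction_mul_not_converted
//   CHECK-NOT: vector.contract
//       CHECK: vector.multi_reduction
func @multidimreduction_mul_not_converted(
  %arg0: vector<8x32x16xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
  %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
  %1 = vector.multi_reduction #vector.kind<mul>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
  return %1 : vector<8x16xf32>
}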

@@ -493,6 +493,23 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
}
};
struct TestVectorReduceToContractPatternsPatterns
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
FunctionPass> {
StringRef getArgument() const final {
return "test-vector-reduction-to-contract-patterns";
}
StringRef getDescription() const final {
return "Test patterns to convert multireduce op to contract and combine "
"broadcast/transpose to contract";
}
void runOnFunction() override {
RewritePatternSet patterns(&getContext());
populateVetorReductionToContractPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
} // end anonymous namespace
namespace mlir {
@@ -519,6 +536,8 @@ void registerTestVectorConversions() {
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
}
} // namespace test
} // namespace mlir