forked from OSchip/llvm-project
[mlir][vector] Add patterns to convert multidimreduce to vector.contract
add several patterns that will simplify contraction vectorization in the future. With those canonicalizationns we will be able to remove the special case for contration during vectorization and rely on those transformations to avoid materizalizing broadcast ops. Differential Revision: https://reviews.llvm.org/D112121
This commit is contained in:
parent
5dc339d982
commit
1d8cc45b0e
|
@ -215,6 +215,10 @@ void populateVectorTransposeLoweringPatterns(
|
|||
RewritePatternSet &patterns,
|
||||
VectorTransformsOptions options = VectorTransformsOptions());
|
||||
|
||||
/// Collect patterns to convert reduction op to vector.contract and fold
|
||||
/// transpose/broadcast ops into the contract.
|
||||
void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Returns the integer type required for subscripts in the vector dialect.
|
||||
IntegerType getVectorSubscriptType(Builder &builder);
|
||||
|
||||
|
|
|
@ -240,6 +240,13 @@ sliceTransferIndices(int64_t index, ArrayRef<int64_t> originalShape,
|
|||
return slicedIndices;
|
||||
}
|
||||
|
||||
template <typename IntType>
|
||||
static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
|
||||
return llvm::to_vector<4>(llvm::map_range(
|
||||
arrayAttr.getAsRange<IntegerAttr>(),
|
||||
[](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct UnrollTransferReadPattern
|
||||
|
@ -1114,6 +1121,193 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
|
||||
/// Ex:
|
||||
/// ```
|
||||
/// %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
|
||||
/// %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
|
||||
/// : vector<8x32x16xf32> to vector<8x16xf32>
|
||||
/// ```
|
||||
/// Gets converted to:
|
||||
/// ```
|
||||
/// %1 = vector.contract {indexing_maps = [
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
struct MultiReduceToContract
|
||||
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||
using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (reduceOp.kind() != vector::CombiningKind::ADD)
|
||||
return failure();
|
||||
Operation *mulOp = reduceOp.source().getDefiningOp();
|
||||
if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
|
||||
return failure();
|
||||
SmallVector<bool> reductionMask = reduceOp.getReductionMask();
|
||||
auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
|
||||
SmallVector<AffineExpr> exprs;
|
||||
SmallVector<StringRef> iteratorTypes;
|
||||
for (auto isReduceDim : llvm::enumerate(reductionMask)) {
|
||||
if (!isReduceDim.value()) {
|
||||
iteratorTypes.push_back(getParallelIteratorTypeName());
|
||||
exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
|
||||
} else {
|
||||
iteratorTypes.push_back(getReductionIteratorTypeName());
|
||||
}
|
||||
}
|
||||
auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
|
||||
/*symCount=*/0, exprs, reduceOp.getContext());
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
reduceOp.getLoc(), reduceOp.getDestType(),
|
||||
rewriter.getZeroAttr(reduceOp.getDestType()));
|
||||
rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
|
||||
reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
|
||||
rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
|
||||
rewriter.getStrArrayAttr(iteratorTypes));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Merge TransposeOp into ContractionOp user.
|
||||
/// Ex:
|
||||
/// ```
|
||||
/// %0 = vector.transpose %arg0, [2, 0, 1]
|
||||
/// : vector<32x16x8xf32> to vector<8x32x16xf32>
|
||||
/// %1 = vector.contract {indexing_maps = [
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
/// Gets converted to:
|
||||
/// ```
|
||||
/// %1 = vector.contract {indexing_maps = [
|
||||
/// affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
struct CombineContractTranspose
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<AffineMap, 4> maps =
|
||||
llvm::to_vector<4>(contractOp.getIndexingMaps());
|
||||
Value lhs = contractOp.lhs();
|
||||
Value rhs = contractOp.rhs();
|
||||
size_t index = 0;
|
||||
bool changed = false;
|
||||
for (Value *operand : {&lhs, &rhs}) {
|
||||
AffineMap &map = maps[index++];
|
||||
auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
|
||||
if (!transposeOp)
|
||||
continue;
|
||||
SmallVector<int64_t> perm;
|
||||
transposeOp.getTransp(perm);
|
||||
AffineMap permutationMap = AffineMap::getPermutationMap(
|
||||
extractVector<unsigned>(transposeOp.transp()),
|
||||
contractOp.getContext());
|
||||
map = inversePermutation(permutationMap).compose(map);
|
||||
*operand = transposeOp.vector();
|
||||
changed = true;
|
||||
}
|
||||
if (!changed)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
|
||||
contractOp, lhs, rhs, contractOp.acc(),
|
||||
rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Merge BroadcastOp into ContractionOp user.
|
||||
/// Ex:
|
||||
/// ```
|
||||
/// %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
|
||||
/// %1 = vector.contract {indexing_maps = [
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %0, %arg1, %cst_f0
|
||||
/// : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
/// Gets converted to:
|
||||
/// ```
|
||||
/// %1 = vector.contract {indexing_maps = [
|
||||
/// affine_map<(d0, d1, d2) -> (d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
/// affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||
/// iterator_types = ["parallel", "parallel", "reduction"],
|
||||
/// kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
|
||||
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
/// ```
|
||||
struct CombineContractBroadcast
|
||||
: public OpRewritePattern<vector::ContractionOp> {
|
||||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<AffineMap, 4> maps =
|
||||
llvm::to_vector<4>(contractOp.getIndexingMaps());
|
||||
Value lhs = contractOp.lhs();
|
||||
Value rhs = contractOp.rhs();
|
||||
size_t index = 0;
|
||||
bool changed = false;
|
||||
for (Value *operand : {&lhs, &rhs}) {
|
||||
AffineMap &map = maps[index++];
|
||||
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
|
||||
if (!broadcast)
|
||||
continue;
|
||||
// contractionOp can only take vector as operands.
|
||||
auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
|
||||
if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
|
||||
continue;
|
||||
int64_t rankDiff =
|
||||
broadcast.getVectorType().getRank() - srcType.getRank();
|
||||
bool innerDimBroadcast = false;
|
||||
SmallVector<AffineExpr> originalDims;
|
||||
for (auto dim : llvm::enumerate(srcType.getShape())) {
|
||||
if (dim.value() !=
|
||||
broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
|
||||
innerDimBroadcast = true;
|
||||
break;
|
||||
}
|
||||
originalDims.push_back(
|
||||
rewriter.getAffineDimExpr(dim.index() + rankDiff));
|
||||
}
|
||||
// Contract doesn't support inner dimension broadcast. Once this is
|
||||
// relaxed we can remove this case.
|
||||
if (innerDimBroadcast)
|
||||
continue;
|
||||
AffineMap broadcastMap =
|
||||
AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
|
||||
contractOp.getContext());
|
||||
map = broadcastMap.compose(map);
|
||||
*operand = broadcast.source();
|
||||
changed = true;
|
||||
}
|
||||
if (!changed)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
|
||||
contractOp, lhs, rhs, contractOp.acc(),
|
||||
rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
|
||||
|
@ -3668,6 +3862,12 @@ void mlir::vector::populateVectorTransposeLoweringPatterns(
|
|||
patterns.add<TransposeOpLowering>(options, patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateVetorReductionToContractPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<MultiReduceToContract, CombineContractBroadcast,
|
||||
CombineContractTranspose>(patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<TransferReadPermutationLowering,
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
|
||||
// CHECK-LABEL: multidimreduction_contract
|
||||
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
|
||||
func @multidimreduction_contract(
|
||||
%arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
|
||||
%0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
|
||||
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
|
||||
return %1 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
|
||||
// CHECK-LABEL: multidimreduction_contract_int
|
||||
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
|
||||
func @multidimreduction_contract_int(
|
||||
%arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
|
||||
%0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
|
||||
%1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
|
||||
return %1 : vector<8x16xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_transpose
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
|
||||
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
|
||||
func @contract_transpose(
|
||||
%arg0: vector<32x16x8xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
|
||||
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
%0 = vector.transpose %arg0, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32>
|
||||
%1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
return %1 : vector<8x32xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
||||
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// CHECK-LABEL: contract_broadcast
|
||||
// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
|
||||
// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
|
||||
// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
|
||||
func @contract_broadcast(
|
||||
%arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
|
||||
%cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
|
||||
%0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
|
||||
%1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
|
||||
iterator_types = ["parallel", "parallel", "reduction"],
|
||||
kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
|
||||
return %1 : vector<8x32xf32>
|
||||
}
|
|
@ -493,6 +493,23 @@ struct TestVectorTransferCollapseInnerMostContiguousDims
|
|||
}
|
||||
};
|
||||
|
||||
struct TestVectorReduceToContractPatternsPatterns
|
||||
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
|
||||
FunctionPass> {
|
||||
StringRef getArgument() const final {
|
||||
return "test-vector-reduction-to-contract-patterns";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test patterns to convert multireduce op to contract and combine "
|
||||
"broadcast/transpose to contract";
|
||||
}
|
||||
void runOnFunction() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVetorReductionToContractPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
|
@ -519,6 +536,8 @@ void registerTestVectorConversions() {
|
|||
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
|
||||
|
||||
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
|
||||
|
||||
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
|
|
Loading…
Reference in New Issue