[mlir][vector] Extend vector distribution to all elementwise and contract

Use the elementwise interface to generalize the extract_map canonicalization
pattern, and add a new pattern for the vector.contract case.

Differential Revision: https://reviews.llvm.org/D104343
This commit is contained in:
thomasraoux 2021-06-30 16:22:31 -07:00
parent 0c400e8953
commit 627733b5f0
5 changed files with 147 additions and 45 deletions

View File

@ -105,6 +105,10 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
// a sequence of vector.reduction ops. // a sequence of vector.reduction ops.
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns); void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
/// Collect a set of patterns that propagate insert_map/extract_map ops along
/// the SSA chain.
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
/// An attribute that specifies the combining function for `vector.contract`, /// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`. /// and `vector.reduction`.
class CombiningKindAttr class CombiningKindAttr

View File

@ -251,26 +251,6 @@ Optional<DistributeOps>
distributPointwiseVectorOp(OpBuilder &builder, Operation *op, distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
ArrayRef<Value> id, ArrayRef<int64_t> multiplicity, ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
const AffineMap &map); const AffineMap &map);
/// Canonicalize an extract_map applied to the result of a pointwise operation
/// by moving the extract_map onto the operands of that operation.
/// Transforms:
/// %v = addf %a, %b : vector<32xf32>
/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
/// to:
/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
/// %db = vector.extract_map %b, %id, 32 : vector<32xf32> into vector<1xf32>
/// %dv = addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
/// Callback type for a user-supplied constraint on the ExtractMapOp.
using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
/// Builds the pattern for `context`. `constraint` is stored in `filter`; when
/// omitted it defaults to a callback that returns success() for every op.
PointwiseExtractPattern(
MLIRContext *context, FilterConstraintType constraint =
[](ExtractMapOp op) { return success(); })
: OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
/// Declared here; the definition is out of line.
LogicalResult matchAndRewrite(ExtractMapOp extract,
PatternRewriter &rewriter) const override;
private:
/// Constraint captured at construction time.
FilterConstraintType filter;
};
/// Implements transfer op write to read forwarding and dead transfer write /// Implements transfer op write to read forwarding and dead transfer write
/// optimizations. /// optimizations.

View File

@ -2793,25 +2793,6 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
return failure(); return failure();
} }
/// Rewrites a vector.extract_map whose source is produced by a single-result
/// AddFOp into one extract_map per operand followed by a clone of the AddFOp
/// working on the extracted values.
///
/// Returns failure() (leaving the IR untouched) when the extracted vector has
/// no defining op, when that op does not have exactly one result, when it is
/// not an AddFOp, or when the filter passed to the constructor rejects
/// `extract`.
LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
    ExtractMapOp extract, PatternRewriter &rewriter) const {
  Operation *definedOp = extract.vector().getDefiningOp();
  if (!definedOp || definedOp->getNumResults() != 1)
    return failure();
  // TODO: Create an interfaceOp for elementwise operations.
  if (!isa<AddFOp>(definedOp))
    return failure();
  // Honor the constraint supplied at construction time. An empty
  // std::function is treated as "no constraint" instead of being invoked.
  if (filter && failed(filter(extract)))
    return failure();

  Location loc = extract.getLoc();
  // Extract the same slice from every operand of the pointwise op.
  SmallVector<Value, 4> extractOperands;
  for (OpOperand &operand : definedOp->getOpOperands())
    extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
        loc, extract.getResultType(), operand.get(), extract.ids()));
  // Re-create the pointwise op on the extracted operands with the extracted
  // result type, and use it in place of the original extract_map.
  Operation *newOp = cloneOpWithOperandsAndTypes(
      rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
  rewriter.replaceOp(extract, newOp->getResult(0));
  return success();
}
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp( Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
OpBuilder &builder, Operation *op, ArrayRef<Value> ids, OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
ArrayRef<int64_t> multiplicity, const AffineMap &map) { ArrayRef<int64_t> multiplicity, const AffineMap &map) {
@ -2843,6 +2824,91 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
return ops; return ops;
} }
/// Moves a vector.extract_map from the result of an elementwise operation onto
/// the vector operands of that operation.
/// Transforms:
///   %v = addf %a, %b : vector<32xf32>
///   %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
/// to:
///   %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
///   %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
///   %dv = addf %da, %db : vector<1xf32>
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    // Only single-result producers with elementwise-mappable traits qualify.
    Operation *producer = extract.vector().getDefiningOp();
    if (!producer)
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(producer))
      return failure();
    if (producer->getNumResults() != 1)
      return failure();

    Location loc = extract.getLoc();
    SmallVector<Value, 4> newOperands;
    for (Value operand : producer->getOperands()) {
      Value distributed = operand;
      // Non-vector operands are forwarded unchanged. Vector operands are
      // extracted with the shape of the extract_map result but keep their own
      // element type.
      if (auto operandVecType = operand.getType().dyn_cast<VectorType>()) {
        VectorType sliceType =
            VectorType::get(extract.getResultType().getShape(),
                            operandVecType.getElementType());
        distributed = rewriter.create<vector::ExtractMapOp>(
            loc, sliceType, operand, extract.ids());
      }
      newOperands.push_back(distributed);
    }

    // Clone the producer on the new operands with the extracted result type
    // and substitute it for the extract_map.
    Operation *distributedOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, producer, newOperands, extract.getResultType());
    rewriter.replaceOp(extract, distributedOp->getResult(0));
    return success();
  }
};
/// Canonicalize an extract_map applied to the result of a vector.contract by
/// propagating the extract_map to the operands of the contraction.
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *producer = extract.vector().getDefiningOp();
    auto contract = dyn_cast_or_null<vector::ContractionOp>(producer);
    if (!contract)
      return failure();

    Location loc = contract.getLoc();
    auto indexingMaps = contract.getIndexingMaps();
    AffineMap accMap =
        indexingMaps[vector::ContractionOp::getAccOperandIndex()];

    // Record, for every iteration dimension that appears in the accumulator
    // map, the size it takes in the extract_map result. Dimensions missing
    // from this table are reductions and are left at their original size.
    DenseMap<int64_t, int64_t> distributedSizes;
    for (unsigned res = 0, numRes = accMap.getNumResults(); res < numRes;
         ++res)
      distributedSizes[accMap.getDimPosition(res)] =
          extract.getResultType().getDimSize(res);

    // NOTE(review): only the operands that have an indexing map are forwarded
    // to the cloned op; confirm no other operands need to be carried over.
    SmallVector<Value, 4> newOperands;
    for (unsigned idx = 0, numMaps = indexingMaps.size(); idx < numMaps;
         ++idx) {
      AffineMap operandMap = indexingMaps[idx];
      Value operand = contract->getOperand(idx);
      auto operandType = operand.getType().cast<VectorType>();

      // Start from the operand's current shape and shrink each distributed
      // dimension to its per-extract size.
      SmallVector<int64_t> newShape(operandType.getShape().begin(),
                                    operandType.getShape().end());
      for (unsigned res = 0, numRes = operandMap.getNumResults(); res < numRes;
           ++res) {
        auto entry = distributedSizes.find(operandMap.getDimPosition(res));
        if (entry != distributedSizes.end())
          newShape[res] = entry->second;
      }

      VectorType sliceType =
          VectorType::get(newShape, operandType.getElementType());
      newOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, sliceType, operand, extract.ids()));
    }

    // Rebuild the contraction on the extracted operands and use it in place
    // of the original extract_map.
    Operation *distributedOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, producer, newOperands, extract.getResult().getType());
    rewriter.replaceOp(extract, distributedOp->getResult(0));
    return success();
  }
};
/// Converts TransferRead op used by ExtractMap op into a smaller dimension /// Converts TransferRead op used by ExtractMap op into a smaller dimension
/// TransferRead. /// TransferRead.
/// Example: /// Example:
@ -4100,8 +4166,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
// TODO: Add this as DRR pattern. // TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns( void mlir::vector::populateVectorToVectorTransformationPatterns(
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp, patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp>(
TransferReadExtractPattern, TransferWriteInsertPattern>(
patterns.getContext()); patterns.getContext());
} }
@ -4112,6 +4177,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
ignoreFilter); ignoreFilter);
} }
/// Adds the extract_map/insert_map propagation patterns to `patterns`:
/// PointwiseExtractPattern, ContractExtractPattern, TransferReadExtractPattern
/// and TransferWriteInsertPattern, all built on the set's own context.
void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      context);
}
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
patterns.add<CastAwayExtractStridedSliceLeadingOneDim, patterns.add<CastAwayExtractStridedSliceLeadingOneDim,

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s // RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D
// CHECK-LABEL: func @distribute_vector_add // CHECK-LABEL: func @distribute_vector_add
// CHECK-SAME: (%[[ID:.*]]: index // CHECK-SAME: (%[[ID:.*]]: index
@ -15,6 +16,24 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
// ----- // -----
// CHECK-LABEL: func @distribute_vector_add_exp
// CHECK-SAME: (%[[ID:.*]]: index
// CHECK-NEXT: %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
// CHECK-NEXT: %[[ADDV:.*]] = addf %[[EXPV]], %{{.*}} : vector<32xf32>
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
// CHECK-NEXT: %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXC]], %[[EXB]] : vector<1xf32>
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
// CHECK-NEXT: return %[[INS]] : vector<32xf32>
// Test input: an addf whose first operand is produced by math.exp, i.e. a chain
// of two elementwise ops on vector<32xf32>. The expectations above look for
// both ops being re-created on vector<1xf32> slices.
func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
%C = math.exp %A : vector<32xf32>
%0 = addf %C, %B : vector<32xf32>
return %0: vector<32xf32>
}
// -----
// CHECK-LABEL: func @vector_add_read_write // CHECK-LABEL: func @vector_add_read_write
// CHECK-SAME: (%[[ID:.*]]: index // CHECK-SAME: (%[[ID:.*]]: index
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32> // CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
@ -154,3 +173,32 @@ func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?
vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32> vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
return return
} }
// -----
// Distribution of a contraction followed by an elementwise add. The memref
// arguments are captured by name and every other SSA value is matched with a
// wildcard, so the expectations do not depend on the names the printer picks.
// CHECK2D-LABEL: vector_add_contract
// CHECK2D-SAME: %{{[a-zA-Z0-9_]+}}: index, %{{[a-zA-Z0-9_]+}}: index, %[[MA:[a-zA-Z0-9_]+]]: memref<?x?xf32>, %[[MB:[a-zA-Z0-9_]+]]: memref<?x?xf32>, %[[MC:[a-zA-Z0-9_]+]]: memref<?x?xf32>, %[[MD:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK2D: %[[A:.+]] = vector.transfer_read %[[MA]][%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x4xf32>
// CHECK2D: %[[B:.+]] = vector.transfer_read %[[MB]][%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<16x4xf32>
// CHECK2D: %[[C:.+]] = vector.transfer_read %[[MC]][%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x16xf32>
// CHECK2D: %[[E:.+]] = vector.transfer_read %[[MD]][%{{.*}}, %{{.*}}], %{{.*}} : memref<?x?xf32>, vector<2x16xf32>
// CHECK2D: %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
// CHECK2D: %[[R:.+]] = addf %[[D]], %[[E]] : vector<2x16xf32>
// CHECK2D: vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
%B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
%a = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
%b = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
%c = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
%d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>}
%a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
%e = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
%r = addf %d, %e : vector<64x64xf32>
vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
return
}

View File

@ -275,8 +275,7 @@ struct TestVectorDistributePatterns
} }
} }
}); });
patterns.add<PointwiseExtractPattern>(ctx); populatePropagateVectorDistributionPatterns(patterns);
populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }
}; };
@ -339,8 +338,7 @@ struct TestVectorToLoopPatterns
} }
return mlir::WalkResult::interrupt(); return mlir::WalkResult::interrupt();
}); });
patterns.add<PointwiseExtractPattern>(ctx); populatePropagateVectorDistributionPatterns(patterns);
populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }
}; };