forked from OSchip/llvm-project
[mlir][vector] Extend vector distribution to all elementwise and contract
Uses elementwise interface to generalize canonicalization pattern and add a new pattern for vector.contract case. Differential Revision: https://reviews.llvm.org/D104343
This commit is contained in:
parent
0c400e8953
commit
627733b5f0
|
@ -105,6 +105,10 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
|
|||
// a sequence of vector.reduction ops.
|
||||
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Collect a set of patterns to propagate insert_map/extract_map in the SSA
|
||||
/// chain.
|
||||
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// An attribute that specifies the combining function for `vector.contract`,
|
||||
/// and `vector.reduction`.
|
||||
class CombiningKindAttr
|
||||
|
|
|
@ -251,26 +251,6 @@ Optional<DistributeOps>
|
|||
distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
|
||||
ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
|
||||
const AffineMap &map);
|
||||
/// Canonicalize an extract_map using the result of a pointwise operation.
|
||||
/// Transforms:
|
||||
/// %v = addf %a, %b : vector<32xf32>
|
||||
/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
|
||||
/// to:
|
||||
/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
|
||||
/// %db = vector.extract_map %b, %id, 32 : vector<32xf32> into vector<1xf32>
|
||||
/// %dv = addf %da, %db : vector<1xf32>
|
||||
struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
  /// Signature of the user-provided constraint: it receives an `ExtractMapOp`
  /// and reports a `LogicalResult`.
  using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;

  /// Creates the pattern for `context`. If the caller does not pass a
  /// `constraint`, a callback that unconditionally returns `success()` is
  /// stored instead.
  PointwiseExtractPattern(MLIRContext *context,
                          FilterConstraintType constraint =
                              [](ExtractMapOp) { return success(); })
      : OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}

  /// Rewrite entry point; the definition is provided out of line.
  LogicalResult matchAndRewrite(ExtractMapOp extract,
                                PatternRewriter &rewriter) const override;

private:
  /// Constraint callback captured at construction time.
  FilterConstraintType filter;
};
|
||||
|
||||
/// Implements transfer op write to read forwarding and dead transfer write
|
||||
/// optimizations.
|
||||
|
|
|
@ -2793,25 +2793,6 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
|
|||
return failure();
|
||||
}
|
||||
|
||||
/// Pushes an `extract_map` through the single-result `AddFOp` that produces
/// its source vector.
///
/// The rewrite bails out with `failure()` when:
///   * the user-supplied `filter` constraint rejects `extract`,
///   * the extracted vector has no defining op (e.g. it is a block argument),
///   * the defining op does not have exactly one result, or
///   * the defining op is not an `AddFOp`.
///
/// Otherwise a new `extract_map` (same result type and ids as `extract`) is
/// created for every operand of the defining op, the defining op is cloned on
/// those extracted operands with the type of `extract`'s result, and `extract`
/// is replaced by the clone's result.
LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
    ExtractMapOp extract, PatternRewriter &rewriter) const {
  // Honor the constraint captured by the constructor; previously it was stored
  // but never consulted. An empty std::function is treated as "no constraint".
  if (filter && failed(filter(extract)))
    return failure();

  Operation *definedOp = extract.vector().getDefiningOp();
  if (!definedOp || definedOp->getNumResults() != 1)
    return failure();
  // TODO: Create an interfaceOp for elementwise operations.
  if (!isa<AddFOp>(definedOp))
    return failure();

  Location loc = extract.getLoc();
  // One extract_map per operand, all using the ids of the original extract.
  SmallVector<Value, 4> extractOperands;
  for (OpOperand &operand : definedOp->getOpOperands())
    extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
        loc, extract.getResultType(), operand.get(), extract.ids()));

  // Re-create the defining op on the extracted operands.
  Operation *newOp = cloneOpWithOperandsAndTypes(
      rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
  rewriter.replaceOp(extract, newOp->getResult(0));
  return success();
}
|
||||
|
||||
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
|
||||
OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
|
||||
ArrayRef<int64_t> multiplicity, const AffineMap &map) {
|
||||
|
@ -2843,6 +2824,91 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
|
|||
return ops;
|
||||
}
|
||||
|
||||
/// Canonicalize an extract_map using the result of a pointwise operation.
|
||||
/// Transforms:
|
||||
/// %v = addf %a, %b : vector<32xf32>
|
||||
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
|
||||
/// to:
|
||||
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
|
||||
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
|
||||
/// %dv = addf %da, %db : vector<1xf32>
|
||||
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;

  /// Moves `extract` above the op that produces its source vector.
  ///
  /// Fails unless the source has a defining op that carries the
  /// elementwise-mappable traits and has exactly one result. On success every
  /// vector operand of that op is replaced by an `extract_map` of it (using
  /// the shape of `extract`'s result, the operand's own element type and the
  /// ids of `extract`); non-vector operands are forwarded unchanged. The op is
  /// then cloned on the new operands and `extract` is replaced by the clone's
  /// result.
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *producer = extract.vector().getDefiningOp();
    if (!producer)
      return failure();
    if (!OpTrait::hasElementwiseMappableTraits(producer))
      return failure();
    if (producer->getNumResults() != 1)
      return failure();

    Location loc = extract.getLoc();
    SmallVector<Value, 4> newOperands;
    for (OpOperand &use : producer->getOpOperands()) {
      Value original = use.get();
      if (auto operandVecType = original.getType().dyn_cast<VectorType>()) {
        // Keep the operand's element type; only the shape follows the
        // distributed result.
        newOperands.push_back(rewriter.create<vector::ExtractMapOp>(
            loc,
            VectorType::get(extract.getResultType().getShape(),
                            operandVecType.getElementType()),
            original, extract.ids()));
      } else {
        // Non-vector operands are passed through as they are.
        newOperands.push_back(original);
      }
    }

    Operation *distributedOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, producer, newOperands, extract.getResultType());
    rewriter.replaceOp(extract, distributedOp->getResult(0));
    return success();
  }
};
|
||||
|
||||
/// Canonicalize an extract_map using the result of a contract operation.
|
||||
/// This propagates the extract_map to operands.
|
||||
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;

  /// Matches an `extract_map` whose source vector is defined by a
  /// `vector.contract`; fails otherwise. For each indexing map of the
  /// contraction, the matching operand is wrapped in a new `extract_map`
  /// whose shape is adjusted along the distributed dimensions, the
  /// contraction is cloned on those operands with the type of `extract`'s
  /// result, and `extract` is replaced by the clone's result.
  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
                                PatternRewriter &rewriter) const override {
    Operation *producer = extract.vector().getDefiningOp();
    auto contractOp = dyn_cast_or_null<vector::ContractionOp>(producer);
    if (!contractOp)
      return failure();

    Location loc = contractOp.getLoc();

    // Use the accumulator's indexing map to record, per iteration dimension,
    // the size it has in the distributed result. Only parallel dimensions are
    // distributed; reduction dimensions are left untouched.
    unsigned accOperandIdx = vector::ContractionOp::getAccOperandIndex();
    AffineMap accMap = contractOp.getIndexingMaps()[accOperandIdx];
    DenseMap<int64_t, int64_t> distributedSizes;
    for (unsigned resIdx : llvm::seq(unsigned(0), accMap.getNumResults()))
      distributedSizes[accMap.getDimPosition(resIdx)] =
          extract.getResultType().getDimSize(resIdx);

    SmallVector<Value, 4> newOperands;
    for (auto indexedMap : llvm::enumerate(contractOp.getIndexingMaps())) {
      // Compute the vector type this operand has after distribution.
      Value operand = contractOp->getOperand(indexedMap.index());
      auto operandType = operand.getType().cast<VectorType>();
      ArrayRef<int64_t> originalShape = operandType.getShape();
      SmallVector<int64_t> newShape(originalShape.begin(),
                                    originalShape.end());
      AffineMap operandMap = indexedMap.value();
      for (unsigned resIdx :
           llvm::seq(unsigned(0), operandMap.getNumResults())) {
        unsigned dim = operandMap.getDimPosition(resIdx);
        auto entry = distributedSizes.find(dim);
        // A dimension missing from the map is a reduction dimension and
        // keeps its original size.
        if (entry != distributedSizes.end())
          newShape[resIdx] = entry->second;
      }
      VectorType distributedType =
          VectorType::get(newShape, operandType.getElementType());
      newOperands.push_back(rewriter.create<vector::ExtractMapOp>(
          loc, distributedType, operand, extract.ids()));
    }

    Operation *distributedContract =
        cloneOpWithOperandsAndTypes(rewriter, loc, producer, newOperands,
                                    extract.getResult().getType());
    rewriter.replaceOp(extract, distributedContract->getResult(0));
    return success();
  }
};
|
||||
|
||||
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
|
||||
/// TransferRead.
|
||||
/// Example:
|
||||
|
@ -4100,8 +4166,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
|
|||
// TODO: Add this as DRR pattern.
|
||||
/// Registers `ShapeCastOpDecomposer`, `ShapeCastOpFolder` and
/// `TupleGetFolderOp` into `patterns`, using the context owned by `patterns`.
///
/// `TransferReadExtractPattern` and `TransferWriteInsertPattern` are not added
/// here; they are registered by `populatePropagateVectorDistributionPatterns`.
void mlir::vector::populateVectorToVectorTransformationPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp>(
      patterns.getContext());
}
|
||||
|
||||
|
@ -4112,6 +4177,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
|
|||
ignoreFilter);
|
||||
}
|
||||
|
||||
/// Registers the four patterns that move `extract_map`/`insert_map` through
/// their producers/consumers: `PointwiseExtractPattern`,
/// `ContractExtractPattern`, `TransferReadExtractPattern` and
/// `TransferWriteInsertPattern`.
void mlir::vector::populatePropagateVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<PointwiseExtractPattern, ContractExtractPattern,
               TransferReadExtractPattern, TransferWriteInsertPattern>(
      context);
}
|
||||
|
||||
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D
|
||||
|
||||
// CHECK-LABEL: func @distribute_vector_add
|
||||
// CHECK-SAME: (%[[ID:.*]]: index
|
||||
|
@ -15,6 +16,24 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @distribute_vector_add_exp
|
||||
// CHECK-SAME: (%[[ID:.*]]: index
|
||||
// CHECK-NEXT: %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
|
||||
// CHECK-NEXT: %[[ADDV:.*]] = addf %[[EXPV]], %{{.*}} : vector<32xf32>
|
||||
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
|
||||
// CHECK-NEXT: %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
|
||||
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
|
||||
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXC]], %[[EXB]] : vector<1xf32>
|
||||
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
|
||||
// CHECK-NEXT: return %[[INS]] : vector<32xf32>
|
||||
func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
  // Two chained elementwise ops: the exponential of %A is added to %B.
  %exp = math.exp %A : vector<32xf32>
  %sum = addf %exp, %B : vector<32xf32>
  return %sum : vector<32xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @vector_add_read_write
|
||||
// CHECK-SAME: (%[[ID:.*]]: index
|
||||
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
|
||||
|
@ -154,3 +173,32 @@ func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?
|
|||
vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK2D-LABEL: vector_add_contract
|
||||
// CHECK2D: %[[A:.+]] = vector.transfer_read %arg2[%0, %c0], %cst : memref<?x?xf32>, vector<2x4xf32>
|
||||
// CHECK2D: %[[B:.+]] = vector.transfer_read %arg3[%2, %c0], %cst : memref<?x?xf32>, vector<16x4xf32>
|
||||
// CHECK2D: %[[C:.+]] = vector.transfer_read %arg4[%4, %5], %cst : memref<?x?xf32>, vector<2x16xf32>
|
||||
// CHECK2D: %[[E:.+]] = vector.transfer_read %arg5[%7, %8], %cst : memref<?x?xf32>, vector<2x16xf32>
|
||||
// CHECK2D: %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
|
||||
// CHECK2D: %[[R:.+]] = addf %[[D]], %[[E]] : vector<2x16xf32>
|
||||
// CHECK2D: vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
|
||||
func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
  %B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  // Operands of the contraction: lhs from %A, rhs from %B, accumulator from %C.
  %lhs = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
  %rhs = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
  %acc = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
  // d0 and d1 are parallel, d2 is the reduction dimension.
  %prod = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                                            affine_map<(d0, d1, d2) -> (d1, d2)>,
                                            affine_map<(d0, d1, d2) -> (d0, d1)>],
                           iterator_types = ["parallel", "parallel", "reduction"],
                           kind = #vector.kind<add>}
    %lhs, %rhs, %acc : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
  // Elementwise add of the contraction result with a vector read from %D.
  %addend = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
  %sum = addf %prod, %addend : vector<64x64xf32>
  vector.transfer_write %sum, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
  return
}
|
||||
|
|
|
@ -275,8 +275,7 @@ struct TestVectorDistributePatterns
|
|||
}
|
||||
}
|
||||
});
|
||||
patterns.add<PointwiseExtractPattern>(ctx);
|
||||
populateVectorToVectorTransformationPatterns(patterns);
|
||||
populatePropagateVectorDistributionPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
@ -339,8 +338,7 @@ struct TestVectorToLoopPatterns
|
|||
}
|
||||
return mlir::WalkResult::interrupt();
|
||||
});
|
||||
patterns.add<PointwiseExtractPattern>(ctx);
|
||||
populateVectorToVectorTransformationPatterns(patterns);
|
||||
populatePropagateVectorDistributionPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue