forked from OSchip/llvm-project
[mlir][vector] Extend vector distribution to all elementwise and contract
Uses the elementwise interface to generalize the canonicalization pattern and adds a new pattern for the vector.contract case. Differential Revision: https://reviews.llvm.org/D104343
This commit is contained in:
parent
0c400e8953
commit
627733b5f0
|
@ -105,6 +105,10 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
|
||||||
// a sequence of vector.reduction ops.
|
// a sequence of vector.reduction ops.
|
||||||
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
|
void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Collect a set of patterns to propagate insert_map/extract_map in the ssa
|
||||||
|
/// chain.
|
||||||
|
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
/// An attribute that specifies the combining function for `vector.contract`,
|
/// An attribute that specifies the combining function for `vector.contract`,
|
||||||
/// and `vector.reduction`.
|
/// and `vector.reduction`.
|
||||||
class CombiningKindAttr
|
class CombiningKindAttr
|
||||||
|
|
|
@ -251,26 +251,6 @@ Optional<DistributeOps>
|
||||||
distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
|
distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
|
||||||
ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
|
ArrayRef<Value> id, ArrayRef<int64_t> multiplicity,
|
||||||
const AffineMap &map);
|
const AffineMap &map);
|
||||||
/// Canonicalize an extract_map using the result of a pointwise operation.
|
|
||||||
/// Transforms:
|
|
||||||
/// %v = addf %a, %b : vector<32xf32>
|
|
||||||
/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
|
|
||||||
/// to:
|
|
||||||
/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
|
|
||||||
/// %db = vector.extract_map %b, %id, 32 : vector<32xf32> into vector<1xf32>
|
|
||||||
/// %dv = addf %da, %db : vector<1xf32>
|
|
||||||
struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
|
|
||||||
using FilterConstraintType = std::function<LogicalResult(ExtractMapOp op)>;
|
|
||||||
PointwiseExtractPattern(
|
|
||||||
MLIRContext *context, FilterConstraintType constraint =
|
|
||||||
[](ExtractMapOp op) { return success(); })
|
|
||||||
: OpRewritePattern<ExtractMapOp>(context), filter(constraint) {}
|
|
||||||
LogicalResult matchAndRewrite(ExtractMapOp extract,
|
|
||||||
PatternRewriter &rewriter) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
FilterConstraintType filter;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Implements transfer op write to read forwarding and dead transfer write
|
/// Implements transfer op write to read forwarding and dead transfer write
|
||||||
/// optimizations.
|
/// optimizations.
|
||||||
|
|
|
@ -2793,25 +2793,6 @@ LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
|
|
||||||
ExtractMapOp extract, PatternRewriter &rewriter) const {
|
|
||||||
Operation *definedOp = extract.vector().getDefiningOp();
|
|
||||||
if (!definedOp || definedOp->getNumResults() != 1)
|
|
||||||
return failure();
|
|
||||||
// TODO: Create an interfaceOp for elementwise operations.
|
|
||||||
if (!isa<AddFOp>(definedOp))
|
|
||||||
return failure();
|
|
||||||
Location loc = extract.getLoc();
|
|
||||||
SmallVector<Value, 4> extractOperands;
|
|
||||||
for (OpOperand &operand : definedOp->getOpOperands())
|
|
||||||
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
|
|
||||||
loc, extract.getResultType(), operand.get(), extract.ids()));
|
|
||||||
Operation *newOp = cloneOpWithOperandsAndTypes(
|
|
||||||
rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
|
|
||||||
rewriter.replaceOp(extract, newOp->getResult(0));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
|
Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
|
||||||
OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
|
OpBuilder &builder, Operation *op, ArrayRef<Value> ids,
|
||||||
ArrayRef<int64_t> multiplicity, const AffineMap &map) {
|
ArrayRef<int64_t> multiplicity, const AffineMap &map) {
|
||||||
|
@ -2843,6 +2824,91 @@ Optional<mlir::vector::DistributeOps> mlir::vector::distributPointwiseVectorOp(
|
||||||
return ops;
|
return ops;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Canonicalize an extract_map using the result of a pointwise operation.
|
||||||
|
/// Transforms:
|
||||||
|
/// %v = addf %a, %b : vector<32xf32>
|
||||||
|
/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
|
||||||
|
/// to:
|
||||||
|
/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
|
||||||
|
/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
|
||||||
|
/// %dv = addf %da, %db : vector<1xf32>
|
||||||
|
struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
|
||||||
|
using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Operation *definedOp = extract.vector().getDefiningOp();
|
||||||
|
if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
|
||||||
|
definedOp->getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
Location loc = extract.getLoc();
|
||||||
|
SmallVector<Value, 4> extractOperands;
|
||||||
|
for (OpOperand &operand : definedOp->getOpOperands()) {
|
||||||
|
auto vecType = operand.get().getType().template dyn_cast<VectorType>();
|
||||||
|
if (!vecType) {
|
||||||
|
extractOperands.push_back(operand.get());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
|
||||||
|
loc,
|
||||||
|
VectorType::get(extract.getResultType().getShape(),
|
||||||
|
vecType.getElementType()),
|
||||||
|
operand.get(), extract.ids()));
|
||||||
|
}
|
||||||
|
Operation *newOp = cloneOpWithOperandsAndTypes(
|
||||||
|
rewriter, loc, definedOp, extractOperands, extract.getResultType());
|
||||||
|
rewriter.replaceOp(extract, newOp->getResult(0));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Canonicalize an extract_map using the result of a contract operation.
|
||||||
|
/// This propagates the extract_map to operands.
|
||||||
|
struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
|
||||||
|
using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Operation *definedOp = extract.vector().getDefiningOp();
|
||||||
|
auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
|
||||||
|
if (!contract)
|
||||||
|
return failure();
|
||||||
|
Location loc = contract.getLoc();
|
||||||
|
unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
|
||||||
|
AffineMap affineMap = contract.getIndexingMaps()[accIndex];
|
||||||
|
// Create a map of the dimensions distributed based on the acc affine map.
|
||||||
|
// Only parallel dimensions are being distributed, reduction dimensions are
|
||||||
|
// untouched.
|
||||||
|
DenseMap<int64_t, int64_t> map;
|
||||||
|
for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults()))
|
||||||
|
map[affineMap.getDimPosition(i)] = extract.getResultType().getDimSize(i);
|
||||||
|
SmallVector<Value, 4> extractOperands;
|
||||||
|
for (auto it : llvm::enumerate(contract.getIndexingMaps())) {
|
||||||
|
// For each operand, calculate the new vector type after distribution.
|
||||||
|
Value operand = contract->getOperand(it.index());
|
||||||
|
auto vecType = operand.getType().cast<VectorType>();
|
||||||
|
SmallVector<int64_t> operandShape(vecType.getShape().begin(),
|
||||||
|
vecType.getShape().end());
|
||||||
|
for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
|
||||||
|
unsigned dim = it.value().getDimPosition(i);
|
||||||
|
auto distributedDim = map.find(dim);
|
||||||
|
// If the dimension is not in the map it means it is a reduction and
|
||||||
|
// doesn't get distributed.
|
||||||
|
if (distributedDim == map.end())
|
||||||
|
continue;
|
||||||
|
operandShape[i] = distributedDim->second;
|
||||||
|
}
|
||||||
|
VectorType newVecType =
|
||||||
|
VectorType::get(operandShape, vecType.getElementType());
|
||||||
|
extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
|
||||||
|
loc, newVecType, operand, extract.ids()));
|
||||||
|
}
|
||||||
|
Operation *newOp =
|
||||||
|
cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
|
||||||
|
extract.getResult().getType());
|
||||||
|
rewriter.replaceOp(extract, newOp->getResult(0));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
|
/// Converts TransferRead op used by ExtractMap op into a smaller dimension
|
||||||
/// TransferRead.
|
/// TransferRead.
|
||||||
/// Example:
|
/// Example:
|
||||||
|
@ -4100,8 +4166,7 @@ void mlir::vector::populateVectorMaskMaterializationPatterns(
|
||||||
// TODO: Add this as DRR pattern.
|
// TODO: Add this as DRR pattern.
|
||||||
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
void mlir::vector::populateVectorToVectorTransformationPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
|
patterns.add<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp>(
|
||||||
TransferReadExtractPattern, TransferWriteInsertPattern>(
|
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4112,6 +4177,13 @@ void mlir::vector::populateSplitVectorTransferPatterns(
|
||||||
ignoreFilter);
|
ignoreFilter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void mlir::vector::populatePropagateVectorDistributionPatterns(
|
||||||
|
RewritePatternSet &patterns) {
|
||||||
|
patterns.add<PointwiseExtractPattern, ContractExtractPattern,
|
||||||
|
TransferReadExtractPattern, TransferWriteInsertPattern>(
|
||||||
|
patterns.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
|
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
|
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
|
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s
|
||||||
|
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,4 -split-input-file | FileCheck %s --check-prefix=CHECK2D
|
||||||
|
|
||||||
// CHECK-LABEL: func @distribute_vector_add
|
// CHECK-LABEL: func @distribute_vector_add
|
||||||
// CHECK-SAME: (%[[ID:.*]]: index
|
// CHECK-SAME: (%[[ID:.*]]: index
|
||||||
|
@ -15,6 +16,24 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @distribute_vector_add_exp
|
||||||
|
// CHECK-SAME: (%[[ID:.*]]: index
|
||||||
|
// CHECK-NEXT: %[[EXPV:.*]] = math.exp %{{.*}} : vector<32xf32>
|
||||||
|
// CHECK-NEXT: %[[ADDV:.*]] = addf %[[EXPV]], %{{.*}} : vector<32xf32>
|
||||||
|
// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
|
||||||
|
// CHECK-NEXT: %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
|
||||||
|
// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
|
||||||
|
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXC]], %[[EXB]] : vector<1xf32>
|
||||||
|
// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
|
||||||
|
// CHECK-NEXT: return %[[INS]] : vector<32xf32>
|
||||||
|
func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
|
||||||
|
%C = math.exp %A : vector<32xf32>
|
||||||
|
%0 = addf %C, %B : vector<32xf32>
|
||||||
|
return %0: vector<32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @vector_add_read_write
|
// CHECK-LABEL: func @vector_add_read_write
|
||||||
// CHECK-SAME: (%[[ID:.*]]: index
|
// CHECK-SAME: (%[[ID:.*]]: index
|
||||||
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
|
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
|
||||||
|
@ -154,3 +173,32 @@ func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?
|
||||||
vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
|
vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK2D-LABEL: vector_add_contract
|
||||||
|
// CHECK2D: %[[A:.+]] = vector.transfer_read %arg2[%0, %c0], %cst : memref<?x?xf32>, vector<2x4xf32>
|
||||||
|
// CHECK2D: %[[B:.+]] = vector.transfer_read %arg3[%2, %c0], %cst : memref<?x?xf32>, vector<16x4xf32>
|
||||||
|
// CHECK2D: %[[C:.+]] = vector.transfer_read %arg4[%4, %5], %cst : memref<?x?xf32>, vector<2x16xf32>
|
||||||
|
// CHECK2D: %[[E:.+]] = vector.transfer_read %arg5[%7, %8], %cst : memref<?x?xf32>, vector<2x16xf32>
|
||||||
|
// CHECK2D: %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
|
||||||
|
// CHECK2D: %[[R:.+]] = addf %[[D]], %[[E]] : vector<2x16xf32>
|
||||||
|
// CHECK2D: vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref<?x?xf32>
|
||||||
|
func @vector_add_contract(%id0 : index, %id1 : index, %A: memref<?x?xf32>,
|
||||||
|
%B: memref<?x?xf32>, %C: memref<?x?xf32>, %D: memref<?x?xf32>) {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%cf0 = constant 0.0 : f32
|
||||||
|
%a = vector.transfer_read %A[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
|
||||||
|
%b = vector.transfer_read %B[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x4xf32>
|
||||||
|
%c = vector.transfer_read %C[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
|
||||||
|
%d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
|
||||||
|
affine_map<(d0, d1, d2) -> (d1, d2)>,
|
||||||
|
affine_map<(d0, d1, d2) -> (d0, d1)>],
|
||||||
|
iterator_types = ["parallel", "parallel", "reduction"],
|
||||||
|
kind = #vector.kind<add>}
|
||||||
|
%a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
|
||||||
|
%e = vector.transfer_read %D[%c0, %c0], %cf0 : memref<?x?xf32>, vector<64x64xf32>
|
||||||
|
%r = addf %d, %e : vector<64x64xf32>
|
||||||
|
vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref<?x?xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -275,8 +275,7 @@ struct TestVectorDistributePatterns
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
patterns.add<PointwiseExtractPattern>(ctx);
|
populatePropagateVectorDistributionPatterns(patterns);
|
||||||
populateVectorToVectorTransformationPatterns(patterns);
|
|
||||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -339,8 +338,7 @@ struct TestVectorToLoopPatterns
|
||||||
}
|
}
|
||||||
return mlir::WalkResult::interrupt();
|
return mlir::WalkResult::interrupt();
|
||||||
});
|
});
|
||||||
patterns.add<PointwiseExtractPattern>(ctx);
|
populatePropagateVectorDistributionPatterns(patterns);
|
||||||
populateVectorToVectorTransformationPatterns(patterns);
|
|
||||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue