forked from OSchip/llvm-project
[mlir][vector] Add unit test for vector distribute by block
When distributing a vector larger than the given multiplicity, we can distribute it by block, where each id gets a chunk of consecutive elements along the distributed dimension. This adds a test for this case and adds extra checks to make sure we don't distribute in cases where the vector size is not a multiple of the multiplicity. Differential Revision: https://reviews.llvm.org/D89061
This commit is contained in:
parent
afff74e5c2
commit
cf402a1987
|
@ -2444,7 +2444,14 @@ mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
|
||||||
OpBuilder::InsertionGuard guard(builder);
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
builder.setInsertionPointAfter(op);
|
builder.setInsertionPointAfter(op);
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
if (op->getNumResults() != 1)
|
||||||
|
return {};
|
||||||
Value result = op->getResult(0);
|
Value result = op->getResult(0);
|
||||||
|
VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
|
||||||
|
// Currently only support distributing 1-D vectors of size multiple of the
|
||||||
|
// given multiplicty. To handle more sizes we would need to support masking.
|
||||||
|
if (!type || type.getRank() != 1 || type.getNumElements() % multiplicity != 0)
|
||||||
|
return {};
|
||||||
DistributeOps ops;
|
DistributeOps ops;
|
||||||
ops.extract =
|
ops.extract =
|
||||||
builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);
|
builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s
|
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: func @distribute_vector_add
|
// CHECK-LABEL: func @distribute_vector_add
|
||||||
// CHECK-SAME: (%[[ID:.*]]: index
|
// CHECK-SAME: (%[[ID:.*]]: index
|
||||||
|
@ -14,12 +14,12 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
|
||||||
|
|
||||||
// CHECK-LABEL: func @vector_add_read_write
|
// CHECK-LABEL: func @vector_add_read_write
|
||||||
// CHECK-SAME: (%[[ID:.*]]: index
|
// CHECK-SAME: (%[[ID:.*]]: index
|
||||||
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
|
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
|
||||||
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
|
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
|
||||||
// CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
|
// CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
|
||||||
// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
|
// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
|
||||||
// CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
|
// CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
|
||||||
// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%{{.*}}] : vector<1xf32>, memref<32xf32>
|
// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32>
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
|
func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
|
@ -32,3 +32,41 @@ func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>,
|
||||||
vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
|
vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @vector_add_cycle
|
||||||
|
// CHECK-SAME: (%[[ID:.*]]: index
|
||||||
|
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
|
||||||
|
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
|
||||||
|
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2xf32>
|
||||||
|
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID]]] : vector<2xf32>, memref<64xf32>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%cf0 = constant 0.0 : f32
|
||||||
|
%a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<64xf32>
|
||||||
|
%b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<64xf32>
|
||||||
|
%acc = addf %a, %b: vector<64xf32>
|
||||||
|
vector.transfer_write %acc, %C[%c0]: vector<64xf32>, memref<64xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Negative test to make sure nothing is done in case the vector size is not a
|
||||||
|
// multiple of multiplicity.
|
||||||
|
// CHECK-LABEL: func @vector_negative_test
|
||||||
|
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||||
|
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
|
||||||
|
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
|
||||||
|
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<16xf32>
|
||||||
|
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]]] {{.*}} : vector<16xf32>, memref<64xf32>
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
func @vector_negative_test(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%cf0 = constant 0.0 : f32
|
||||||
|
%a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<16xf32>
|
||||||
|
%b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<16xf32>
|
||||||
|
%acc = addf %a, %b: vector<16xf32>
|
||||||
|
vector.transfer_write %acc, %C[%c0]: vector<16xf32>, memref<64xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -127,10 +127,16 @@ struct TestVectorUnrollingPatterns
|
||||||
|
|
||||||
struct TestVectorDistributePatterns
|
struct TestVectorDistributePatterns
|
||||||
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
|
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
|
||||||
|
TestVectorDistributePatterns() = default;
|
||||||
|
TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<VectorDialect>();
|
registry.insert<VectorDialect>();
|
||||||
registry.insert<AffineDialect>();
|
registry.insert<AffineDialect>();
|
||||||
}
|
}
|
||||||
|
Option<int32_t> multiplicity{
|
||||||
|
*this, "distribution-multiplicity",
|
||||||
|
llvm::cl::desc("Set the multiplicity used for distributing vector"),
|
||||||
|
llvm::cl::init(32)};
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
MLIRContext *ctx = &getContext();
|
MLIRContext *ctx = &getContext();
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
@ -138,10 +144,11 @@ struct TestVectorDistributePatterns
|
||||||
func.walk([&](AddFOp op) {
|
func.walk([&](AddFOp op) {
|
||||||
OpBuilder builder(op);
|
OpBuilder builder(op);
|
||||||
Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
|
Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
|
||||||
builder, op.getOperation(), func.getArgument(0), 32);
|
builder, op.getOperation(), func.getArgument(0), multiplicity);
|
||||||
assert(ops.hasValue());
|
if (ops.hasValue()) {
|
||||||
SmallPtrSet<Operation *, 1> extractOp({ops->extract});
|
SmallPtrSet<Operation *, 1> extractOp({ops->extract});
|
||||||
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
|
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
patterns.insert<PointwiseExtractPattern>(ctx);
|
patterns.insert<PointwiseExtractPattern>(ctx);
|
||||||
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
populateVectorToVectorTransformationPatterns(patterns, ctx);
|
||||||
|
|
Loading…
Reference in New Issue