[mlir] Support tensor types in unrolled VectorToSCF

Differential Revision: https://reviews.llvm.org/D102668
Matthias Springer 2021-06-02 10:38:59 +09:00
parent 616ac1b961
commit bd20756d2c
2 changed files with 59 additions and 7 deletions
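
With lower-tensors=true, the fully unrolled patterns now also fire when the transfer source is a tensor. Tensors have value semantics, so each unrolled 1-D vector.transfer_write produces a new tensor that has to be fed into the next write, and the final tensor replaces the result of the original op. A minimal sketch of the resulting IR for a statically in-bounds 2-D write, mirroring the new test added below (value names are illustrative; the per-row index arithmetic is elided):

    // Before: one 2-D transfer_write into a tensor.
    %t = vector.transfer_write %vec, %A[%i, %j] {in_bounds = [true, true]}
        : vector<2x3xf32>, tensor<?x?xf32>

    // After -convert-vector-to-scf='full-unroll=true lower-tensors=true':
    // each extracted row gets its own 1-D transfer_write, every write yields
    // a new tensor that feeds the next one, and the last result replaces %t.
    %v0 = vector.extract %vec[0] : vector<2x3xf32>
    %t0 = vector.transfer_write %v0, %A[%i, %j] {in_bounds = [true]}
        : vector<3xf32>, tensor<?x?xf32>
    %v1 = vector.extract %vec[1] : vector<2x3xf32>
    // %i1 is %i advanced by one row (computation elided).
    %t1 = vector.transfer_write %v1, %t0[%i1, %j] {in_bounds = [true]}
        : vector<3xf32>, tensor<?x?xf32>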


@@ -866,7 +866,7 @@ struct UnrollTransferReadConversion
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
-    if (xferOp.getShapedType().template isa<RankedTensorType>())
+    if (isTensorOp(xferOp) && !options.lowerTensors)
       return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (xferOp.getVectorType().getElementType() !=
@@ -988,7 +988,7 @@ struct UnrollTransferWriteConversion
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
       return failure();
-    if (xferOp.getShapedType().template isa<RankedTensorType>())
+    if (isTensorOp(xferOp) && !options.lowerTensors)
       return failure();
     // Transfer ops that modify the element type are not supported atm.
     if (xferOp.getVectorType().getElementType() !=
@@ -998,15 +998,19 @@ struct UnrollTransferWriteConversion
     auto vec = getDataVector(xferOp);
     auto xferVecType = xferOp.getVectorType();
     int64_t dimSize = xferVecType.getShape()[0];
+    auto source = xferOp.source(); // memref or tensor to be written to.
+    auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();

     // Generate fully unrolled loop of transfer ops.
     Location loc = xferOp.getLoc();
     for (int64_t i = 0; i < dimSize; ++i) {
       Value iv = rewriter.create<ConstantIndexOp>(loc, i);

-      generateInBoundsCheck(
+      auto updatedSource = generateInBoundsCheck(
           rewriter, xferOp, iv, unpackedDim(xferOp),
-          /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
+          isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
+          /*inBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
             // Indices for the new transfer op.
             SmallVector<Value, 8> xferIndices;
             getXferIndices(b, xferOp, iv, xferIndices);
@@ -1019,17 +1023,29 @@ struct UnrollTransferWriteConversion
             auto extracted =
                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());

             auto newXferOp = b.create<vector::TransferWriteOp>(
-                loc, Type(), extracted, xferOp.source(), xferIndices,
+                loc, sourceType, extracted, source, xferIndices,
                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                 inBoundsAttr);

             maybeAssignMask(b, xferOp, newXferOp, i);
+
+            return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
+          },
+          /*outOfBoundsCase=*/
+          [&](OpBuilder &b, Location loc) {
+            return isTensorOp(xferOp) ? source : Value();
           });
+
+      if (isTensorOp(xferOp))
+        source = updatedSource;
     }

-    rewriter.eraseOp(xferOp);
+    if (isTensorOp(xferOp))
+      rewriter.replaceOp(xferOp, source);
+    else
+      rewriter.eraseOp(xferOp);
+
     return success();
   }
 };
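
For the unrolled write, the bounds check itself also becomes tensor-aware: generateInBoundsCheck is now handed the source tensor type as a result type, the in-bounds branch yields the tensor produced by the new vector.transfer_write, and the out-of-bounds branch yields the source unchanged. Roughly, assuming a dimension that needs a runtime check (value names illustrative, index arithmetic elided), the generated check has this shape:

    %t_next = scf.if %in_bounds -> (tensor<?x?xf32>) {
      %w = vector.transfer_write %row, %t_prev[%idx0, %idx1] {in_bounds = [true]}
          : vector<3xf32>, tensor<?x?xf32>
      scf.yield %w : tensor<?x?xf32>
    } else {
      // Out of bounds: pass the tensor through unchanged.
      scf.yield %t_prev : tensor<?x?xf32>
    }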


@@ -0,0 +1,36 @@
// RUN: mlir-opt %s -convert-vector-to-scf='full-unroll=true lower-tensors=true' -split-input-file -allow-unregistered-dialect | FileCheck %s
// CHECK-LABEL: func @transfer_read_2d(
// CHECK: %[[V_INIT:.*]] = constant dense<-4.200000e+01> : vector<4x9xf32>
// CHECK: %[[V0:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
// CHECK: %[[I0:.*]] = vector.insert %[[V0]], %[[V_INIT]] [0] : vector<9xf32> into vector<4x9xf32>
// CHECK: %[[V1:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
// CHECK: %[[I1:.*]] = vector.insert %[[V1]], %[[I0]] [1] : vector<9xf32> into vector<4x9xf32>
// CHECK: %[[V2:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
// CHECK: %[[I2:.*]] = vector.insert %[[V2]], %[[I1]] [2] : vector<9xf32> into vector<4x9xf32>
// CHECK: %[[V3:.*]] = vector.transfer_read %{{.*}}[{{.*}}], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<9xf32>
// CHECK: %[[I3:.*]] = vector.insert %[[V3]], %[[I2]] [3] : vector<9xf32> into vector<4x9xf32>
// CHECK: return %[[I3]] : vector<4x9xf32>
func @transfer_read_2d(%A : tensor<?x?xf32>, %base1 : index, %base2 : index)
    -> (vector<4x9xf32>) {
  %p = constant -42.0: f32
  %f = vector.transfer_read %A[%base1, %base2], %p {in_bounds = [true, true]}
      : tensor<?x?xf32>, vector<4x9xf32>
  return %f : vector<4x9xf32>
}

// -----

// CHECK-LABEL: func @transfer_write_2d(
// CHECK: %[[V0:.*]] = vector.extract %{{.*}}[0] : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.transfer_write %[[V0]], %{{.*}}[{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
// CHECK: %[[V1:.*]] = vector.extract %{{.*}}[1] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.transfer_write %[[V1]], %[[T0]][{{.*}}] {in_bounds = [true]} : vector<3xf32>, tensor<?x?xf32>
// CHECK: return %[[T1]] : tensor<?x?xf32>
func @transfer_write_2d(%A : tensor<?x?xf32>, %vec : vector<2x3xf32>,
                        %base1 : index, %base2 : index) -> (tensor<?x?xf32>) {
  %t = vector.transfer_write %vec, %A[%base1, %base2] {in_bounds = [true, true]}
      : vector<2x3xf32>, tensor<?x?xf32>
  return %t : tensor<?x?xf32>
}