forked from OSchip/llvm-project
[mlir] [VectorOps] Handle 'vector.shape_cast' lowering for all cases
Summary: Even though this operation is intended for 1d/2d conversions currently, leaving a semantic hole in the lowering prohibits proper testing of this operation. This CL adds a straightforward reference implementation for the missing cases. Reviewers: nicolasvasilache, mehdi_amini, ftynse, reidtatge Reviewed By: reidtatge Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, msifontes Tags: #mlir Differential Revision: https://reviews.llvm.org/D81503
This commit is contained in:
parent
f56659d2ba
commit
1e45b55dcc
|
@ -1466,6 +1466,61 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// We typically should not lower general shape cast operations into data
|
||||
// movement instructions, since the assumption is that these casts are
|
||||
// optimized away during progressive lowering. For completeness, however,
|
||||
// we fall back to a reference implementation that moves all elements
|
||||
// into the right place if we get here.
|
||||
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
|
||||
public:
|
||||
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
auto sourceVectorType = op.getSourceVectorType();
|
||||
auto resultVectorType = op.getResultVectorType();
|
||||
// Intended 2D/1D lowerings with better implementations.
|
||||
int64_t srcRank = sourceVectorType.getRank();
|
||||
int64_t resRank = resultVectorType.getRank();
|
||||
if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
|
||||
return failure();
|
||||
// Compute number of elements involved in the reshape.
|
||||
int64_t numElts = 1;
|
||||
for (int64_t r = 0; r < srcRank; r++)
|
||||
numElts *= sourceVectorType.getDimSize(r);
|
||||
// Replace with data movement operations:
|
||||
// x[0,0,0] = y[0,0]
|
||||
// x[0,0,1] = y[0,1]
|
||||
// x[0,1,0] = y[0,2]
|
||||
// etc., incrementing the two index vectors "row-major"
|
||||
// within the source and result shape.
|
||||
SmallVector<int64_t, 4> srcIdx(srcRank);
|
||||
SmallVector<int64_t, 4> resIdx(resRank);
|
||||
Value result = rewriter.create<ConstantOp>(
|
||||
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
|
||||
for (int64_t i = 0; i < numElts; i++) {
|
||||
if (i != 0) {
|
||||
incIdx(srcIdx, sourceVectorType, srcRank - 1);
|
||||
incIdx(resIdx, resultVectorType, resRank - 1);
|
||||
}
|
||||
Value e = rewriter.create<vector::ExtractOp>(loc, op.source(), srcIdx);
|
||||
result = rewriter.create<vector::InsertOp>(loc, e, result, resIdx);
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
static void incIdx(SmallVector<int64_t, 4> &idx, VectorType tp, int64_t r) {
|
||||
assert(0 <= r && r < tp.getRank());
|
||||
if (++idx[r] == tp.getDimSize(r)) {
|
||||
idx[r] = 0;
|
||||
incIdx(idx, tp, r - 1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
|
@ -1864,7 +1919,8 @@ void mlir::vector::populateVectorContractLoweringPatterns(
|
|||
ConstantMaskOpLowering,
|
||||
OuterProductOpLowering,
|
||||
ShapeCastOp2DDownCastRewritePattern,
|
||||
ShapeCastOp2DUpCastRewritePattern>(context);
|
||||
ShapeCastOp2DUpCastRewritePattern,
|
||||
ShapeCastOpRewritePattern>(context);
|
||||
patterns.insert<TransposeOpLowering,
|
||||
ContractionOpLowering,
|
||||
ContractionOpToMatmulOpLowering,
|
||||
|
|
|
@ -319,7 +319,6 @@ func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
|
|||
return %0 : vector<3x2xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @nop_shape_cast
|
||||
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
|
||||
// CHECK: return %[[A]] : vector<16xf32>
|
||||
|
@ -378,6 +377,72 @@ func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
|
|||
return %r0, %1 : vector<4xf32>, vector<2x2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_cast_2d2d
|
||||
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
|
||||
// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x3xf32>
|
||||
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
|
||||
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
|
||||
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
|
||||
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
|
||||
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
|
||||
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
|
||||
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
|
||||
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
|
||||
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
|
||||
// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
|
||||
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
|
||||
// CHECK: return %[[T11]] : vector<2x3xf32>
|
||||
|
||||
// Test input: casts a 3x2 vector to a 2x3 vector (2-D to 2-D). The CHECK
// lines above expect a zero constant followed by six extract/insert pairs.
func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
|
||||
%s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
|
||||
return %s : vector<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_cast_3d1d
|
||||
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
|
||||
// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<6xf32>
|
||||
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
|
||||
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
|
||||
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
|
||||
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
|
||||
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
|
||||
// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
|
||||
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
|
||||
// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
|
||||
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
|
||||
// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
|
||||
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
|
||||
// CHECK: return %[[T11]] : vector<6xf32>
|
||||
|
||||
// Test input: casts a 1x3x2 vector to a 6-element vector (3-D to 1-D). The
// CHECK lines above expect a zero constant followed by six extract/insert
// pairs.
func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
|
||||
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
|
||||
return %s : vector<6xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @shape_cast_1d3d
|
||||
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
|
||||
// CHECK: %[[C:.*]] = constant dense<0.000000e+00> : vector<2x1x3xf32>
|
||||
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
|
||||
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
|
||||
// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
|
||||
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
|
||||
// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
|
||||
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
|
||||
// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
|
||||
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
|
||||
// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
|
||||
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
|
||||
// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
|
||||
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
|
||||
// CHECK: return %[[T11]] : vector<2x1x3xf32>
|
||||
|
||||
// Test input: casts a 6-element vector to a 2x1x3 vector (1-D to 3-D). The
// CHECK lines above expect a zero constant followed by six extract/insert
// pairs.
func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
|
||||
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
|
||||
return %s : vector<2x1x3xf32>
|
||||
}
|
||||
|
||||
// MATRIX-LABEL: func @matmul
|
||||
// MATRIX-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
|
||||
// MATRIX-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>,
|
||||
|
|
Loading…
Reference in New Issue