DimOp folding for alloc/view dynamic dimensions
Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Closes tensorflow/mlir#253
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/253 from bondhugula:dimop a4b464f24ae63fd259114558d87e11b8ee4dae86
PiperOrigin-RevId: 284169689
commit 3ade6a7d15 (parent 84a6182ddd)
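What the change enables, as a minimal sketch in the style of the tests below (the function and value names here are illustrative, not taken from the patch): a dim op that queries a dynamic dimension of an alloc now folds to the corresponding dynamic-size operand instead of surviving canonicalization.

func @fold_dim_of_alloc(%w : index, %h : index) -> index {
  %0 = alloc(%w, %h) : memref<?x?xf32>
  // Dimension 1 is dynamic, so the dim folds to the alloc size operand %h.
  %d = dim %0, 1 : memref<?x?xf32>
  return %d : index
}
// After canonicalization, the returned value is simply %h.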
@@ -1364,11 +1364,26 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  else if (auto memrefType = opType.dyn_cast<MemRefType>())
    indexSize = memrefType.getShape()[getIndex()];

  if (indexSize >= 0)
  if (!ShapedType::isDynamic(indexSize))
    return IntegerAttr::get(IndexType::get(getContext()), indexSize);

  // Fold dim to the size argument of a SubViewOp.
  // Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp.
  auto memrefType = opType.dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // The size at getIndex() is now a dynamic size of a memref.
  auto memref = memrefOrTensor()->getDefiningOp();
  if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(getIndex()));

  if (auto view = dyn_cast_or_null<ViewOp>(memref))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(getIndex()));

  // The subview op here is expected to have rank dynamic sizes now.
  if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
    auto sizes = subview.sizes();
    if (!sizes.empty())
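The detail worth noting in the AllocOp/ViewOp cases is the index mapping: the dim index is a position in the full shape, while the op only carries one size operand per dynamic dimension, and getDynamicDimIndex translates between the two. A small illustration, assuming the partially static shape used in the canonicalize test further below (names are illustrative):

func @fold_dim_mixed_shape(%a : index, %b : index) -> index {
  // memref<?x8x?xf32> has dynamic dimensions 0 and 2, so alloc takes two
  // size operands: %a for dimension 0 and %b for dimension 2.
  %m = alloc(%a, %b) : memref<?x8x?xf32>
  // getDynamicDimIndex(2) == 1, so the fold returns the second operand, %b.
  %d = dim %m, 2 : memref<?x8x?xf32>
  return %d : index
}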
@@ -46,10 +46,7 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d
}
// CHECK: %[[tensor:[0-9]+]] = alloc
// CHECK-NOT: {{.*}} dim %[[tensor]], 0
// CHECK: {{.*}} dim %[[tensor]], 1
// CHECK: {{.*}} dim %[[tensor]], 2
// CHECK-NOT: {{.*}} dim %[[tensor]], 3
// CHECK: {{.*}} dim %[[tensor]], 4
return
}
@@ -66,36 +63,32 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
// CHECK: %[[D0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D3:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32>
// CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector.type_cast %[[ALLOC]] : memref<5x4x3xf32>
// CHECK-NEXT: loop.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK-NEXT: loop.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-NEXT: loop.for %[[I6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D0]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[L0:.*]] = select
//
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D1]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[L1:.*]] = select
//
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D2]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[L2:.*]] = select
//
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D3]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
@@ -144,10 +137,6 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} step 4 {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
// CHECK: %[[D0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D3:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32>
// CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector.type_cast {{.*}} : memref<5x4x3xf32>
// CHECK: store %{{.*}}, {{.*}} : memref<vector<5x4x3xf32>>
@@ -155,27 +144,27 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: loop.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-NEXT: loop.for %[[I6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D0]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[S0:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
//
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D1]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[S1:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
//
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D2]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %{{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, %[[I2]], {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %[[C0]] : index
// CHECK-NEXT: %[[S2:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
//
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D3]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
@@ -22,13 +22,13 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
  linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
  return
}
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>,
// CHECK-SAME: [[M:arg[0-9]+]]: index
// CHECK-SAME: [[N:arg[0-9]+]]: index
// CHECK-SAME: [[K:arg[0-9]+]]: index
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[N:.*]] = dim %[[B]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
@@ -48,12 +48,12 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
  linalg.matvec(%2, %3, %4) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
  return
}
// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index) {
// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>,
// CHECK-SAME: [[M:arg[0-9]+]]: index
// CHECK-SAME: [[K:arg[0-9]+]]: index
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
// CHECK-DAG: %[[a:.*]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
@@ -72,11 +72,11 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
  linalg.dot(%1, %2, %3) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
  return
}
// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>, %{{.*}}: index) {
// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>,
// CHECK-SAME: [[K:arg[0-9]+]]: index
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[][] : memref<?xi8> to memref<f32>
// CHECK: %[[K:.*]] = dim %[[A]], 0 : memref<?xf32, #[[strided1D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
// CHECK-DAG: %[[a:.*]] = load %[[A]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
// CHECK-DAG: %[[b:.*]] = load %[[B]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
@@ -418,6 +418,62 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
  return %c, %d : memref<? x ? x i32>, memref<? x ? x f32>
}

#map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
#map2 = (d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)

// CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index,
func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index, %BUF: memref<?xi8>, %M : index, %N : index, %K : index) {
// CHECK-SAME: [[M:arg[0-9]+]]: index
// CHECK-SAME: [[N:arg[0-9]+]]: index
// CHECK-SAME: [[K:arg[0-9]+]]: index
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
  %1 = alloc(%arg1, %arg2) : memref<?x8x?xf32>
  %2 = dim %1, 2 : memref<?x8x?xf32>
  affine.for %arg3 = 0 to %2 {
    %3 = alloc(%arg0) : memref<?xi8>
    %ub = dim %3, 0 : memref<?xi8>
    affine.for %arg4 = 0 to %ub {
      %s = dim %0, 0 : memref<?x?xf32>
      %v = std.view %3[%c0][%arg4, %s] : memref<?xi8> to memref<?x?xf32, #map1>
      %sv = std.subview %0[%c0, %c0][%s,%arg4][%c1,%c1] : memref<?x?xf32> to memref<?x?xf32, #map1>
      %l = dim %v, 1 : memref<?x?xf32, #map1>
      %u = dim %sv, 0 : memref<?x?xf32, #map1>
      affine.for %arg5 = %l to %u {
        "foo"() : () -> ()
      }
    }
  }
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: affine.for %arg7 = 0 to %arg2 {
// CHECK-NEXT: affine.for %arg8 = 0 to %arg0 {
// CHECK-NEXT: affine.for %arg9 = %arg0 to %arg0 {
// CHECK-NEXT: "foo"() : () -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }

  %A = view %BUF[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
  %B = view %BUF[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
  %C = view %BUF[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>

  %M_ = dim %A, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
  %K_ = dim %A, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
  %N_ = dim %C, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
  loop.for %i = %c0 to %M_ step %c1 {
    loop.for %j = %c0 to %N_ step %c1 {
      loop.for %k = %c0 to %K_ step %c1 {
      }
    }
  }
// CHECK: loop.for %{{.*}} = %c0 to %[[M]] step %c1 {
// CHECK: loop.for %arg8 = %c0 to %[[N]] step %c1 {
// CHECK: loop.for %arg9 = %c0 to %[[K]] step %c1 {
  return
}

// CHECK-LABEL: func @merge_constants
func @merge_constants() -> (index, index) {
  // CHECK-NEXT: %c42 = constant 42 : index
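The dim ops on %v and %sv in the test above are what exercise the new view/subview cases. Pulled out into a self-contained sketch (the function and value names are illustrative, and the strided memref type is borrowed from the views used later in that test):

func @fold_dim_of_view(%buf : memref<?xi8>, %offset : index, %n : index, %m : index) -> index {
  // The view's dynamic sizes are %n (dimension 0) and %m (dimension 1).
  %v = view %buf[%offset][%n, %m] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
  // With this change, the dim folds to %m, the size operand for dimension 1.
  %d = dim %v, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
  return %d : index
}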
@@ -743,7 +799,7 @@ func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
  load %4[%c0, %c0, %c0] : memref<?x?x?xf32,
    (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>

  // Test: subview offset operands are folded correctly w.r.t. base strides.
  // Test: subview offset operands are folded correctly w.r.t. base strides.
  // CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
  %5 = subview %0[%c1, %c2, %c7][%c7, %c11, %c2][%c1, %c1, %c1]
    : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to