DimOp folding for alloc/view dynamic dimensions

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#253

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/253 from bondhugula:dimop a4b464f24ae63fd259114558d87e11b8ee4dae86
PiperOrigin-RevId: 284169689
This commit is contained in:
Uday Bondhugula 2019-12-06 05:59:06 -08:00 committed by A. Unique TensorFlower
parent 84a6182ddd
commit 3ade6a7d15
4 changed files with 91 additions and 31 deletions

View File

@ -1364,11 +1364,26 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
else if (auto memrefType = opType.dyn_cast<MemRefType>())
indexSize = memrefType.getShape()[getIndex()];
if (indexSize >= 0)
if (!ShapedType::isDynamic(indexSize))
return IntegerAttr::get(IndexType::get(getContext()), indexSize);
// Fold dim to the size argument of a SubViewOp.
// Fold dim to the size argument for an AllocOp/ViewOp/SubViewOp.
auto memrefType = opType.dyn_cast<MemRefType>();
if (!memrefType)
return {};
// The size at getIndex() is now a dynamic size of a memref.
auto memref = memrefOrTensor()->getDefiningOp();
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
return *(alloc.getDynamicSizes().begin() +
memrefType.getDynamicDimIndex(getIndex()));
if (auto view = dyn_cast_or_null<ViewOp>(memref))
return *(view.getDynamicSizes().begin() +
memrefType.getDynamicDimIndex(getIndex()));
// The subview op here is expected to have `rank` dynamic sizes now (i.e. one per dimension).
if (auto subview = dyn_cast_or_null<SubViewOp>(memref)) {
auto sizes = subview.sizes();
if (!sizes.empty())

View File

@ -46,10 +46,7 @@ func @materialize_read_1d_partially_specialized(%dyn1 : index, %dyn2 : index, %d
}
// CHECK: %[[tensor:[0-9]+]] = alloc
// CHECK-NOT: {{.*}} dim %[[tensor]], 0
// CHECK: {{.*}} dim %[[tensor]], 1
// CHECK: {{.*}} dim %[[tensor]], 2
// CHECK-NOT: {{.*}} dim %[[tensor]], 3
// CHECK: {{.*}} dim %[[tensor]], 4
return
}
@ -66,36 +63,32 @@ func @materialize_read(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
// CHECK: %[[D0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D3:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32>
// CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector.type_cast %[[ALLOC]] : memref<5x4x3xf32>
// CHECK-NEXT: loop.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
// CHECK-NEXT: loop.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-NEXT: loop.for %[[I6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D0]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[L0:.*]] = select
//
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D1]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[L1:.*]] = select
//
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D2]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[L2:.*]] = select
//
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D3]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}} : index
// CHECK-NEXT: {{.*}} = select
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
@ -144,10 +137,6 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: affine.for %[[I1:.*]] = 0 to %{{.*}} step 4 {
// CHECK-NEXT: affine.for %[[I2:.*]] = 0 to %{{.*}} {
// CHECK-NEXT: affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
// CHECK: %[[D0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32>
// CHECK-NEXT: %[[D3:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32>
// CHECK: %[[ALLOC:.*]] = alloc() : memref<5x4x3xf32>
// CHECK-NEXT: %[[VECTOR_VIEW:.*]] = vector.type_cast {{.*}} : memref<5x4x3xf32>
// CHECK: store %{{.*}}, {{.*}} : memref<vector<5x4x3xf32>>
@ -155,27 +144,27 @@ func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
// CHECK-NEXT: loop.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
// CHECK-NEXT: loop.for %[[I6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I0]], %[[I4]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D0]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[S0:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
//
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I1]], %[[I5]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D1]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index
// CHECK-NEXT: %[[S1:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
//
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D2]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %{{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, %[[I2]], {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", %[[I2]], %[[C0]] : index
// CHECK-NEXT: %[[S2:.*]] = select {{.*}}, %[[C0]], {{.*}} : index
//
// CHECK-NEXT: {{.*}} = affine.apply #[[ADD]](%[[I3]], %[[I6]])
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%[[D3]]]
// CHECK-NEXT: {{.*}} = affine.apply #[[SUB]]()[%{{.*}}]
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = select {{.*}}, {{.*}}, {{.*}} : index
// CHECK-NEXT: {{.*}} = cmpi "slt", {{.*}}, %[[C0]] : index

View File

@ -22,13 +22,13 @@ func @matmul(%arg0: memref<?xi8>, %M: index, %N: index, %K: index) {
linalg.matmul(%A, %B, %C) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
return
}
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-LABEL: func @matmul(%{{.*}}: memref<?xi8>,
// CHECK-SAME: [[M:arg[0-9]+]]: index
// CHECK-SAME: [[N:arg[0-9]+]]: index
// CHECK-SAME: [[K:arg[0-9]+]]: index
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[N:.*]] = dim %[[B]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
@ -48,12 +48,12 @@ func @matvec(%arg0: memref<?xi8>, %M: index, %N: index) {
linalg.matvec(%2, %3, %4) : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>
return
}
// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>, %{{.*}}: index, %{{.*}}: index) {
// CHECK-LABEL: func @matvec(%{{.*}}: memref<?xi8>,
// CHECK-SAME: [[M:arg[0-9]+]]: index
// CHECK-SAME: [[K:arg[0-9]+]]: index
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
// CHECK-DAG: %[[a:.*]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
@ -72,11 +72,11 @@ func @dot(%arg0: memref<?xi8>, %M: index) {
linalg.dot(%1, %2, %3) : memref<?xf32, offset: ?, strides: [1]>, memref<?xf32, offset: ?, strides: [1]>, memref<f32>
return
}
// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>, %{{.*}}: index) {
// CHECK-LABEL: func @dot(%{{.*}}: memref<?xi8>,
// CHECK-SAME: [[K:arg[0-9]+]]: index
// CHECK: %[[A:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[B:.*]] = std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?xf32, #[[strided1D]]>
// CHECK: %[[C:.*]] = std.view %{{.*}}[][] : memref<?xi8> to memref<f32>
// CHECK: %[[K:.*]] = dim %[[A]], 0 : memref<?xf32, #[[strided1D]]>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
// CHECK-DAG: %[[a:.*]] = load %[[A]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
// CHECK-DAG: %[[b:.*]] = load %[[B]][%{{.*}}] : memref<?xf32, #[[strided1D]]>

View File

@ -418,6 +418,62 @@ func @dyn_shape_fold(%L : index, %M : index) -> (memref<? x ? x i32>, memref<? x
return %c, %d : memref<? x ? x i32>, memref<? x ? x f32>
}
#map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)
#map2 = (d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)
// Test: dim ops whose operand is produced by an alloc, view, or subview with
// dynamic sizes are expected to fold to the matching dynamic size operand.
// The function takes three index arguments used as alloc sizes, a raw byte
// buffer (%BUF), and three more index arguments (%M, %N, %K) used as view
// sizes; the latter three are captured by name in the signature match below.
// CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index,
func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index, %BUF: memref<?xi8>, %M : index, %N : index, %K : index) {
// CHECK-SAME: [[M:arg[0-9]+]]: index
// CHECK-SAME: [[N:arg[0-9]+]]: index
// CHECK-SAME: [[K:arg[0-9]+]]: index
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = alloc(%arg0, %arg1) : memref<?x?xf32>
%1 = alloc(%arg1, %arg2) : memref<?x8x?xf32>
// Dimension 2 is the second dynamic dimension of %1 (dimension 1 is the
// static 8), so this dim is expected to fold to the second alloc operand,
// %arg2.
%2 = dim %1, 2 : memref<?x8x?xf32>
affine.for %arg3 = 0 to %2 {
%3 = alloc(%arg0) : memref<?xi8>
// The only dynamic size of %3 is %arg0, so %ub is expected to become %arg0.
%ub = dim %3, 0 : memref<?xi8>
affine.for %arg4 = 0 to %ub {
// %s is dimension 0 of %0, i.e. the first alloc operand %arg0.
%s = dim %0, 0 : memref<?x?xf32>
%v = std.view %3[%c0][%arg4, %s] : memref<?xi8> to memref<?x?xf32, #map1>
%sv = std.subview %0[%c0, %c0][%s,%arg4][%c1,%c1] : memref<?x?xf32> to memref<?x?xf32, #map1>
// %l reads size 1 of the view (%s) and %u reads size 0 of the subview (%s);
// since %s itself folds to %arg0, both loop bounds below are expected to
// end up as %arg0.
%l = dim %v, 1 : memref<?x?xf32, #map1>
%u = dim %sv, 0 : memref<?x?xf32, #map1>
affine.for %arg5 = %l to %u {
"foo"() : () -> ()
}
}
}
// Expected output for the first loop nest: directly after the two constants
// come the three loops, with bounds expressed in terms of the function
// arguments; no alloc, view, subview, or dim op is expected in between.
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: %c1 = constant 1 : index
// CHECK-NEXT: affine.for %arg7 = 0 to %arg2 {
// CHECK-NEXT: affine.for %arg8 = 0 to %arg0 {
// CHECK-NEXT: affine.for %arg9 = %arg0 to %arg0 {
// CHECK-NEXT: "foo"() : () -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// Second part: dim ops on strided views of %BUF. %A has sizes [%M, %K],
// %B has sizes [%K, %N], and %C has sizes [%M, %N].
%A = view %BUF[%c0][%M, %K] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%B = view %BUF[%c0][%K, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
%C = view %BUF[%c0][%M, %N] : memref<?xi8> to memref<?x?xf32, offset: ?, strides: [?, 1]>
// %M_ and %K_ read sizes 0 and 1 of %A; %N_ reads size 1 of %C. They are
// expected to fold to the function arguments %M, %K, and %N respectively.
%M_ = dim %A, 0 : memref<?x?xf32, offset: ?, strides: [?, 1]>
%K_ = dim %A, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
%N_ = dim %C, 1 : memref<?x?xf32, offset: ?, strides: [?, 1]>
loop.for %i = %c0 to %M_ step %c1 {
loop.for %j = %c0 to %N_ step %c1 {
loop.for %k = %c0 to %K_ step %c1 {
}
}
}
// The upper bounds of the three nested loops are expected to be the argument
// names captured from the function signature above (M, N, K in that order).
// CHECK: loop.for %{{.*}} = %c0 to %[[M]] step %c1 {
// CHECK: loop.for %arg8 = %c0 to %[[N]] step %c1 {
// CHECK: loop.for %arg9 = %c0 to %[[K]] step %c1 {
return
}
// CHECK-LABEL: func @merge_constants
func @merge_constants() -> (index, index) {
// CHECK-NEXT: %c42 = constant 42 : index
@ -743,7 +799,7 @@ func @subview(%arg0 : index, %arg1 : index) -> (index, index) {
load %4[%c0, %c0, %c0] : memref<?x?x?xf32,
(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)>
// Test: subview offset operands are folded correctly w.r.t. base strides.
// Test: subview offset operands are folded correctly w.r.t. base strides.
// CHECK: std.subview %[[ALLOC0]][][][] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]>
%5 = subview %0[%c1, %c2, %c7][%c7, %c11, %c2][%c1, %c1, %c1]
: memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to