Fix lower/upper bound mismatch in stripmineSink

Also beef up the corresponding test case.

PiperOrigin-RevId: 236878818
This commit is contained in:
Nicolas Vasilache 2019-03-05 10:50:50 -08:00 committed by jpienaar
parent 2dfefdafea
commit 069c818f40
4 changed files with 26 additions and 17 deletions

View File

@ -27,8 +27,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LIB_EDSC_MLIREMITTER_H_
#define MLIR_LIB_EDSC_MLIREMITTER_H_
#ifndef MLIR_EDSC_MLIREMITTER_H_
#define MLIR_EDSC_MLIREMITTER_H_
#include <tuple>
#include <utility>
@ -185,4 +185,4 @@ private:
} // namespace edsc
} // namespace mlir
#endif // MLIR_LIB_EDSC_MLIREMITTER_H_
#endif // MLIR_EDSC_MLIREMITTER_H_

View File

@ -21,8 +21,8 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LIB_EDSC_TYPES_H_
#define MLIR_LIB_EDSC_TYPES_H_
#ifndef MLIR_EDSC_TYPES_H_
#define MLIR_EDSC_TYPES_H_
#include "mlir-c/Core.h"
#include "mlir/IR/OperationSupport.h"
@ -745,4 +745,4 @@ inline MinExpr Min(llvm::ArrayRef<Expr> args) { return MinExpr(args); }
} // namespace edsc
} // namespace mlir
#endif // MLIR_LIB_EDSC_TYPES_H_
#endif // MLIR_EDSC_TYPES_H_

View File

@ -528,7 +528,7 @@ stripmineSink(OpPointer<AffineForOp> forOp, uint64_t factor,
augmentMapAndBounds(&b, forOp->getInductionVar(), &lbMap, &lbOperands);
// Upper-bound map creation.
auto ubMap = forOp->getLowerBoundMap();
auto ubMap = forOp->getUpperBoundMap();
SmallVector<Value *, 4> ubOperands(forOp->getUpperBoundOperands());
augmentMapAndBounds(&b, forOp->getInductionVar(), &ubMap, &ubOperands,
/*offset=*/scaledStep);

View File

@ -402,16 +402,25 @@ TEST_FUNC(tile_2d) {
// clang-format off
// CHECK-LABEL: func @tile_2d
// CHECK: for %i0 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) step 512 {
// CHECK: for %i1 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) step 1024 {
// CHECK: for %i2 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) {
// CHECK: for %i3 = max {{.*}}, %i0) to min {{.*}}, %i0) step 16 {
// CHECK: for %i4 = max {{.*}}, %i1) to min {{.*}}, %i1) step 32 {
// CHECK: for %i5 = max {{.*}}, %i1, %i4) to min {{.*}}, %i1, %i4) {
// CHECK: for %i6 = max {{.*}}, %i0, %i3) to min {{.*}}, %i0, %i3) {
// CHECK: for %i7 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) {
// CHECK: for %i8 = max {{.*}}, %i0) to min {{.*}}, %i0) {
// CHECK: for %i9 = max {{.*}}, %i1) to min {{.*}}, %i1) {
// CHECK: [[ZERO:%.*]] = constant 0 : index
// CHECK: [[M:%[0-9]+]] = dim %arg0, 0 : memref<?x?x?xf32>
// CHECK: [[N:%[0-9]+]] = dim %arg0, 1 : memref<?x?x?xf32>
// CHECK: [[P:%[0-9]+]] = dim %arg0, 2 : memref<?x?x?xf32>
// CHECK: for %i0 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)([[M]]) step 512 {
// CHECK-NEXT: for %i1 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)([[N]]) step 1024 {
// CHECK-NEXT: for %i2 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)([[P]]) {
// CHECK-NEXT: for %i3 = max (d0, d1) -> (d0, d1)([[ZERO]], %i0) to min (d0, d1) -> (d0, d1 + 512)(%0, %i0) step 16 {
// CHECK-NEXT: for %i4 = max (d0, d1) -> (d0, d1)([[ZERO]], %i1) to min (d0, d1) -> (d0, d1 + 1024)(%1, %i1) step 32 {
// CHECK-NEXT: for %i5 = max (d0, d1, d2) -> (d0, d1, d2)([[ZERO]], %i1, %i4) to min (d0, d1, d2) -> (d0, d1 + 1024, d2 + 32)(%1, %i1, %i4) {
// CHECK-NEXT: for %i6 = max (d0, d1, d2) -> (d0, d1, d2)([[ZERO]], %i0, %i3) to min (d0, d1, d2) -> (d0, d1 + 512, d2 + 16)(%0, %i0, %i3) {
// CHECK: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: for %i7 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)(%2) {
// CHECK-NEXT: for %i8 = max (d0, d1) -> (d0, d1)([[ZERO]], %i0) to min (d0, d1) -> (d0, d1 + 512)(%0, %i0) {
// CHECK-NEXT: for %i9 = max (d0, d1) -> (d0, d1)([[ZERO]], %i1) to min (d0, d1) -> (d0, d1 + 1024)(%1, %i1) {
// clang-format on
f->print(llvm::outs());
}