forked from OSchip/llvm-project
Fix lower/upper bound mismatch in stripmineSink
Also beef up the corresponding test case.
PiperOrigin-RevId: 236878818
This commit is contained in:
parent
2dfefdafea
commit
069c818f40
|
@ -27,8 +27,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_LIB_EDSC_MLIREMITTER_H_
|
||||
#define MLIR_LIB_EDSC_MLIREMITTER_H_
|
||||
#ifndef MLIR_EDSC_MLIREMITTER_H_
|
||||
#define MLIR_EDSC_MLIREMITTER_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
@ -185,4 +185,4 @@ private:
|
|||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_LIB_EDSC_MLIREMITTER_H_
|
||||
#endif // MLIR_EDSC_MLIREMITTER_H_
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_LIB_EDSC_TYPES_H_
|
||||
#define MLIR_LIB_EDSC_TYPES_H_
|
||||
#ifndef MLIR_EDSC_TYPES_H_
|
||||
#define MLIR_EDSC_TYPES_H_
|
||||
|
||||
#include "mlir-c/Core.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
|
@ -745,4 +745,4 @@ inline MinExpr Min(llvm::ArrayRef<Expr> args) { return MinExpr(args); }
|
|||
} // namespace edsc
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_LIB_EDSC_TYPES_H_
|
||||
#endif // MLIR_EDSC_TYPES_H_
|
||||
|
|
|
@ -528,7 +528,7 @@ stripmineSink(OpPointer<AffineForOp> forOp, uint64_t factor,
|
|||
augmentMapAndBounds(&b, forOp->getInductionVar(), &lbMap, &lbOperands);
|
||||
|
||||
// Upper-bound map creation.
|
||||
auto ubMap = forOp->getLowerBoundMap();
|
||||
auto ubMap = forOp->getUpperBoundMap();
|
||||
SmallVector<Value *, 4> ubOperands(forOp->getUpperBoundOperands());
|
||||
augmentMapAndBounds(&b, forOp->getInductionVar(), &ubMap, &ubOperands,
|
||||
/*offset=*/scaledStep);
|
||||
|
|
|
@ -402,16 +402,25 @@ TEST_FUNC(tile_2d) {
|
|||
|
||||
// clang-format off
|
||||
// CHECK-LABEL: func @tile_2d
|
||||
// CHECK: for %i0 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) step 512 {
|
||||
// CHECK: for %i1 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) step 1024 {
|
||||
// CHECK: for %i2 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) {
|
||||
// CHECK: for %i3 = max {{.*}}, %i0) to min {{.*}}, %i0) step 16 {
|
||||
// CHECK: for %i4 = max {{.*}}, %i1) to min {{.*}}, %i1) step 32 {
|
||||
// CHECK: for %i5 = max {{.*}}, %i1, %i4) to min {{.*}}, %i1, %i4) {
|
||||
// CHECK: for %i6 = max {{.*}}, %i0, %i3) to min {{.*}}, %i0, %i3) {
|
||||
// CHECK: for %i7 = (d0) -> (d0)({{.*}}) to (d0) -> (d0)({{.*}}) {
|
||||
// CHECK: for %i8 = max {{.*}}, %i0) to min {{.*}}, %i0) {
|
||||
// CHECK: for %i9 = max {{.*}}, %i1) to min {{.*}}, %i1) {
|
||||
// CHECK: [[ZERO:%.*]] = constant 0 : index
|
||||
// CHECK: [[M:%[0-9]+]] = dim %arg0, 0 : memref<?x?x?xf32>
|
||||
// CHECK: [[N:%[0-9]+]] = dim %arg0, 1 : memref<?x?x?xf32>
|
||||
// CHECK: [[P:%[0-9]+]] = dim %arg0, 2 : memref<?x?x?xf32>
|
||||
// CHECK: for %i0 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)([[M]]) step 512 {
|
||||
// CHECK-NEXT: for %i1 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)([[N]]) step 1024 {
|
||||
// CHECK-NEXT: for %i2 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)([[P]]) {
|
||||
// CHECK-NEXT: for %i3 = max (d0, d1) -> (d0, d1)([[ZERO]], %i0) to min (d0, d1) -> (d0, d1 + 512)(%0, %i0) step 16 {
|
||||
// CHECK-NEXT: for %i4 = max (d0, d1) -> (d0, d1)([[ZERO]], %i1) to min (d0, d1) -> (d0, d1 + 1024)(%1, %i1) step 32 {
|
||||
// CHECK-NEXT: for %i5 = max (d0, d1, d2) -> (d0, d1, d2)([[ZERO]], %i1, %i4) to min (d0, d1, d2) -> (d0, d1 + 1024, d2 + 32)(%1, %i1, %i4) {
|
||||
// CHECK-NEXT: for %i6 = max (d0, d1, d2) -> (d0, d1, d2)([[ZERO]], %i0, %i3) to min (d0, d1, d2) -> (d0, d1 + 512, d2 + 16)(%0, %i0, %i3) {
|
||||
// CHECK: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: for %i7 = (d0) -> (d0)([[ZERO]]) to (d0) -> (d0)(%2) {
|
||||
// CHECK-NEXT: for %i8 = max (d0, d1) -> (d0, d1)([[ZERO]], %i0) to min (d0, d1) -> (d0, d1 + 512)(%0, %i0) {
|
||||
// CHECK-NEXT: for %i9 = max (d0, d1) -> (d0, d1)([[ZERO]], %i1) to min (d0, d1) -> (d0, d1 + 1024)(%1, %i1) {
|
||||
// clang-format on
|
||||
f->print(llvm::outs());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue