fix memref flatten-load (#7298)

This commit is contained in:
Jiahan Xie 2024-10-30 20:42:20 -04:00 committed by GitHub
parent b49d2b3adc
commit 31e4f9eaae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 44 additions and 45 deletions

View File

@ -65,7 +65,8 @@ static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
int64_t indexMulFactor = 1;
// Calculate the product of the i'th index and the [0:i-1] shape dims.
for (unsigned i = 0; i <= memIdx.index(); ++i) {
for (unsigned i = memIdx.index() + 1; i < memrefType.getShape().size();
++i) {
int64_t dimSize = memrefType.getShape()[i];
indexMulFactor *= dimSize;
}
@ -77,15 +78,15 @@ static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(llvm::Log2_64(indexMulFactor)))
.getResult();
partialIdx =
rewriter.create<arith::ShLIOp>(loc, partialIdx, constant).getResult();
finalIdx =
rewriter.create<arith::ShLIOp>(loc, finalIdx, constant).getResult();
} else {
auto constant = rewriter
.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(indexMulFactor))
.getResult();
partialIdx =
rewriter.create<arith::MulIOp>(loc, partialIdx, constant).getResult();
finalIdx =
rewriter.create<arith::MulIOp>(loc, finalIdx, constant).getResult();
}
// Sum up with the prior lower dimension accessors.

View File

@ -1,15 +1,15 @@
// RUN: circt-opt -split-input-file --flatten-memref %s | FileCheck %s
// CHECK-LABEL: func @as_func_arg(
// CHECK: %[[VAL_0:.*]]: memref<16xi32>,
// CHECK: %[[VAL_1:.*]]: index) -> i32 {
// CHECK-SAME: %[[VAL_0:.*]]: memref<16xi32>,
// CHECK-SAME: %[[VAL_1:.*]]: index) -> i32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_3:.*]] = arith.shli %[[VAL_1]], %[[VAL_2]] : index
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_1]], %[[VAL_3]] : index
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_1]] : index
// CHECK: %[[VAL_5:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_4]]] : memref<16xi32>
// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_7:.*]] = arith.shli %[[VAL_1]], %[[VAL_6]] : index
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_1]], %[[VAL_7]] : index
// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_1]] : index
// CHECK: memref.store %[[VAL_5]], %[[VAL_0]]{{\[}}%[[VAL_8]]] : memref<16xi32>
// CHECK: return %[[VAL_5]] : i32
// CHECK: }
@ -23,12 +23,12 @@ func.func @as_func_arg(%a : memref<4x4xi32>, %i : index) -> i32 {
// CHECK-LABEL: func @multidim3(
// CHECK: %[[VAL_0:.*]]: memref<210xi32>, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index) -> i32 {
// CHECK: %[[VAL_4:.*]] = arith.constant 5 : index
// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_2]], %[[VAL_4]] : index
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_1]], %[[VAL_5]] : index
// CHECK: %[[VAL_7:.*]] = arith.constant 30 : index
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_3]], %[[VAL_7]] : index
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_6]], %[[VAL_8]] : index
// CHECK: %[[VAL_4:.*]] = arith.constant 42 : index
// CHECK: %[[VAL_5:.*]] = arith.muli %[[VAL_1]], %[[VAL_4]] : index
// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_2]] : index
// CHECK: %[[VAL_7:.*]] = arith.constant 7 : index
// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_6]], %[[VAL_7]] : index
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : index
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<210xi32>
// CHECK: return %[[VAL_10]] : i32
// CHECK: }
@ -40,20 +40,19 @@ func.func @multidim3(%a : memref<5x6x7xi32>, %i1 : index, %i2 : index, %i3 : ind
// -----
// CHECK-LABEL: func @multidim5(
// CHECK: %[[VAL_0:.*]]: memref<18900xi32>,
// CHECK: %[[VAL_1:.*]]: index) -> i32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 5 : index
// CHECK: %[[VAL_0:.*]]: memref<18900xi32>, %[[VAL_1:.*]]: index) -> i32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 3780 : index
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_1]], %[[VAL_2]] : index
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_1]], %[[VAL_3]] : index
// CHECK: %[[VAL_5:.*]] = arith.constant 30 : index
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_4]], %[[VAL_6]] : index
// CHECK: %[[VAL_8:.*]] = arith.constant 210 : index
// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_1]], %[[VAL_8]] : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = arith.constant 1890 : index
// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_1]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_10]], %[[VAL_12]] : index
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_1]] : index
// CHECK: %[[VAL_5:.*]] = arith.constant 630 : index
// CHECK: %[[VAL_6:.*]] = arith.muli %[[VAL_4]], %[[VAL_5]] : index
// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_1]] : index
// CHECK: %[[VAL_8:.*]] = arith.constant 90 : index
// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_7]], %[[VAL_8]] : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_1]] : index
// CHECK: %[[VAL_11:.*]] = arith.constant 10 : index
// CHECK: %[[VAL_12:.*]] = arith.muli %[[VAL_10]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : index
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<18900xi32>
// CHECK: return %[[VAL_14]] : i32
// CHECK: }
@ -65,20 +64,19 @@ func.func @multidim5(%a : memref<5x6x7x9x10xi32>, %i : index) -> i32 {
// -----
// CHECK-LABEL: func @multidim5_p2(
// CHECK: %[[VAL_0:.*]]: memref<512xi32>,
// CHECK: %[[VAL_1:.*]]: index) -> i32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_0:.*]]: memref<512xi32>, %[[VAL_1:.*]]: index) -> i32 {
// CHECK: %[[VAL_2:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_3:.*]] = arith.shli %[[VAL_1]], %[[VAL_2]] : index
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_1]], %[[VAL_3]] : index
// CHECK: %[[VAL_5:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_6:.*]] = arith.shli %[[VAL_1]], %[[VAL_5]] : index
// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_4]], %[[VAL_6]] : index
// CHECK: %[[VAL_8:.*]] = arith.constant 6 : index
// CHECK: %[[VAL_9:.*]] = arith.shli %[[VAL_1]], %[[VAL_8]] : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_7]], %[[VAL_9]] : index
// CHECK: %[[VAL_11:.*]] = arith.constant 7 : index
// CHECK: %[[VAL_12:.*]] = arith.shli %[[VAL_1]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_10]], %[[VAL_12]] : index
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_1]] : index
// CHECK: %[[VAL_5:.*]] = arith.constant 6 : index
// CHECK: %[[VAL_6:.*]] = arith.shli %[[VAL_4]], %[[VAL_5]] : index
// CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_1]] : index
// CHECK: %[[VAL_8:.*]] = arith.constant 3 : index
// CHECK: %[[VAL_9:.*]] = arith.shli %[[VAL_7]], %[[VAL_8]] : index
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_1]] : index
// CHECK: %[[VAL_11:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_12:.*]] = arith.shli %[[VAL_10]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : index
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_13]]] : memref<512xi32>
// CHECK: return %[[VAL_14]] : i32
// CHECK: }
@ -117,12 +115,12 @@ func.func @allocs() -> memref<4x4xi32> {
// CHECK: cf.cond_br %[[VAL_2]], ^bb1(%[[VAL_3]], %[[VAL_4]] : memref<16xi32>, memref<16xi32>), ^bb1(%[[VAL_4]], %[[VAL_3]] : memref<16xi32>, memref<16xi32>)
// CHECK: ^bb1(%[[VAL_5:.*]]: memref<16xi32>, %[[VAL_6:.*]]: memref<16xi32>):
// CHECK: %[[VAL_7:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_8:.*]] = arith.shli %[[VAL_1]], %[[VAL_7]] : index
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_0]], %[[VAL_8]] : index
// CHECK: %[[VAL_8:.*]] = arith.shli %[[VAL_0]], %[[VAL_7]] : index
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_1]] : index
// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref<16xi32>
// CHECK: %[[VAL_11:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_12:.*]] = arith.shli %[[VAL_1]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_0]], %[[VAL_12]] : index
// CHECK: %[[VAL_12:.*]] = arith.shli %[[VAL_0]], %[[VAL_11]] : index
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_12]], %[[VAL_1]] : index
// CHECK: memref.store %[[VAL_10]], %[[VAL_6]]{{\[}}%[[VAL_13]]] : memref<16xi32>
// CHECK: return
// CHECK: }