Update affine.load folding hook to fold global splat constant loads

Enhance the affine.load folding hook to fold loads from global memrefs whose
constant initializer is a splat; such loads are folded irrespective of their
indices.

Differential Revision: https://reviews.llvm.org/D122292
This commit is contained in:
Uday Bondhugula 2022-03-23 12:16:27 +05:30
parent 3427eddd9a
commit 5576579c86
2 changed files with 33 additions and 14 deletions

View File

@ -2410,17 +2410,22 @@ OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
if (!global)
return {};
if (auto cstAttr =
global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
// We can fold only if we know the indices.
if (!getAffineMap().isConstant())
return {};
auto indices = llvm::to_vector<4>(
llvm::map_range(getAffineMap().getConstantResults(),
[](int64_t v) -> uint64_t { return v; }));
return cstAttr.getValues<Attribute>()[indices];
}
return {};
// Check if the global memref is a constant.
auto cstAttr =
global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
if (!cstAttr)
return {};
// If it's a splat constant, we can fold irrespective of indices.
if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
return splatAttr.getSplatValue<Attribute>();
// Otherwise, we can fold only if we know the indices.
if (!getAffineMap().isConstant())
return {};
auto indices = llvm::to_vector<4>(
llvm::map_range(getAffineMap().getConstantResults(),
[](int64_t v) -> uint64_t { return v; }));
return cstAttr.getValues<Attribute>()[indices];
}
//===----------------------------------------------------------------------===//

View File

@ -1101,6 +1101,7 @@ func @canonicalize_multi_min_max(%i0: index, %i1: index) -> (index, index) {
module {
memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
memref.global "private" constant @__constant_32x64xf32 : memref<32x64xf32> = dense<0.000000e+00>
// CHECK-LABEL: func @fold_const_init_global_memref
func @fold_const_init_global_memref() -> (f32, f32) {
%m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
@ -1109,8 +1110,21 @@ module {
return %v0, %v1 : f32, f32
// CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
// CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
// CHECK-NEXT: return
// CHECK-SAME: %[[C0]]
// CHECK-SAME: %[[C1]]
// CHECK-NEXT: return %[[C0]], %[[C1]]
}
// CHECK-LABEL: func @fold_const_splat_global
func @fold_const_splat_global() -> memref<32x64xf32> {
// Exercises affine.load folding on a global whose initializer is a splat:
// @__constant_32x64xf32 is declared above as dense<0.000000e+00>, so the
// loaded value is expected to be replaced by a single f32 constant.
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
%m = memref.get_global @__constant_32x64xf32 : memref<32x64xf32>
%s = memref.alloc() : memref<32x64xf32>
affine.for %i = 0 to 32 {
affine.for %j = 0 to 64 {
// The load indices are the loop induction variables %i and %j, not
// constants.
%v = affine.load %m[%i, %j] : memref<32x64xf32>
affine.store %v, %s[%i, %j] : memref<32x64xf32>
// The store is expected to take the folded constant in place of %v.
// CHECK: affine.store %[[CST]], %{{.*}}
}
}
return %s: memref<32x64xf32>
}
}