Update affine.load folding hook to fold global splat constant loads
Enhance affine.load folding hook to fold loads on global splat constant memrefs.

Differential Revision: https://reviews.llvm.org/D122292
parent 3427eddd9a
commit 5576579c86
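For illustration, a minimal sketch of the kind of IR this change now folds (the global @cst_splat, function @f, and the memref<16xf32> shape are made up for this example and are not part of the commit): a load from a splat-initialized constant global folds to a scalar constant even when the load indices are not statically known, because every element of a splat has the same value.

memref.global "private" constant @cst_splat : memref<16xf32> = dense<1.000000e+00>

func @f(%i: index) -> f32 {
  %m = memref.get_global @cst_splat : memref<16xf32>
  // %i is not a compile-time constant, yet the load still folds
  // because the initializer is a splat.
  %v = affine.load %m[%i] : memref<16xf32>
  return %v : f32
}

// After canonicalization, %v is expected to be replaced by
// `arith.constant 1.000000e+00 : f32` and the affine.load goes away.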
@@ -2410,17 +2410,22 @@ OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
       SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
   if (!global)
     return {};
-  if (auto cstAttr =
-          global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
-    // We can fold only if we know the indices.
-    if (!getAffineMap().isConstant())
-      return {};
-    auto indices = llvm::to_vector<4>(
-        llvm::map_range(getAffineMap().getConstantResults(),
-                        [](int64_t v) -> uint64_t { return v; }));
-    return cstAttr.getValues<Attribute>()[indices];
-  }
-  return {};
+
+  // Check if the global memref is a constant.
+  auto cstAttr =
+      global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
+  if (!cstAttr)
+    return {};
+  // If it's a splat constant, we can fold irrespective of indices.
+  if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
+    return splatAttr.getSplatValue<Attribute>();
+  // Otherwise, we can fold only if we know the indices.
+  if (!getAffineMap().isConstant())
+    return {};
+  auto indices = llvm::to_vector<4>(
+      llvm::map_range(getAffineMap().getConstantResults(),
+                      [](int64_t v) -> uint64_t { return v; }));
+  return cstAttr.getValues<Attribute>()[indices];
 }
 
 //===----------------------------------------------------------------------===//

@@ -1101,6 +1101,7 @@ func @canonicalize_multi_min_max(%i0: index, %i1: index) -> (index, index) {
 
 module {
   memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
+  memref.global "private" constant @__constant_32x64xf32 : memref<32x64xf32> = dense<0.000000e+00>
   // CHECK-LABEL: func @fold_const_init_global_memref
   func @fold_const_init_global_memref() -> (f32, f32) {
     %m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
@@ -1109,8 +1110,21 @@ module {
     return %v0, %v1 : f32, f32
     // CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
     // CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
-    // CHECK-NEXT: return
-    // CHECK-SAME: %[[C0]]
-    // CHECK-SAME: %[[C1]]
+    // CHECK-NEXT: return %[[C0]], %[[C1]]
   }
+
+  // CHECK-LABEL: func @fold_const_splat_global
+  func @fold_const_splat_global() -> memref<32x64xf32> {
+    // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    %m = memref.get_global @__constant_32x64xf32 : memref<32x64xf32>
+    %s = memref.alloc() : memref<32x64xf32>
+    affine.for %i = 0 to 32 {
+      affine.for %j = 0 to 64 {
+        %v = affine.load %m[%i, %j] : memref<32x64xf32>
+        affine.store %v, %s[%i, %j] : memref<32x64xf32>
+        // CHECK: affine.store %[[CST]], %{{.*}}
+      }
+    }
+    return %s : memref<32x64xf32>
+  }
 }