[mlir][Linalg] Use subview instead of linalg.slice in Promotion.cpp
This revision removes the reliance of Promotion on `linalg.slice`, which is meant for the rank-reducing case. Differential Revision: https://reviews.llvm.org/D77676
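For orientation, a minimal before/after sketch of the promoted partial view (value names, shapes, and the exact operand syntax are illustrative, not taken verbatim from this diff). Promotion's partial view is rank-preserving, so it is now expressed as a plain subview with explicit offsets, sizes, and strides rather than a linalg.slice over !linalg.range operands:

  // Before (sketch): partial view carved out with linalg.slice.
  %partial = linalg.slice %full[%r0, %r1]
      : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #strided2DnoOffset>
  // After (sketch): partial view as a subview with zero offsets, dynamic sizes, unit strides.
  %partial = subview %full[%c0, %c0][%sz0, %sz1][%c1, %c1]
      : memref<?x?xf32> to memref<?x?xf32, #strided2D_dynamic>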
This commit is contained in:
parent 530377018f
commit 3cb1f35df2
@@ -39,17 +39,42 @@ using llvm::SetVector;
using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
using folded_std_dim = folded::ValueBuilder<DimOp>;
using folded_std_subview = folded::ValueBuilder<SubViewOp>;
using folded_std_view = folded::ValueBuilder<ViewOp>;

#define DEBUG_TYPE "linalg-promotion"

static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
/// If `size` comes from an AffineMinOp and one of the dimensions of AffineMin
/// is a constant then return a new value set to the smallest such constant.
/// Otherwise return size.
static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
                                                 Value size) {
  auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
  if (!affineMinOp)
    return size;
  if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) {
        return e.dyn_cast<AffineConstantExpr>();
      }))
    return size;
  int64_t minConst = std::numeric_limits<int64_t>::max();
  for (auto e : affineMinOp.getAffineMap().getResults())
    if (auto cst = e.dyn_cast<AffineConstantExpr>())
      minConst = std::min(minConst, cst.getValue());
  assert(minConst != std::numeric_limits<int64_t>::max());
  return b.create<ConstantIndexOp>(loc, minConst);
}

static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
                         OperationFolder *folder) {
  auto *ctx = size.getContext();
  auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
  if (!dynamicBuffers)
    if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
      return std_alloc(
          MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
  Value mul = std_muli(std_constant_index(width), size);
  Value mul =
      folded_std_muli(folder, folded_std_constant_index(folder, width), size);
  return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
}

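A hedged example of what the new extractSmallestConstantBoundingSize helper enables (the affine map and sizes below are invented for illustration): when a tile size is produced by an affine.min whose map contains a constant result, promotion can bound the allocation by that constant, so allocBuffer can emit a statically shaped byte buffer instead of a dynamic one:

  // %size is the partial tile size; the smallest constant result of the map is 4.
  %size = affine.min affine_map<(d0) -> (4, d0)>(%i)
  // extractSmallestConstantBoundingSize(b, loc, %size) then yields:
  %c4 = constant 4 : index
  // and allocBuffer(f32, %c4, /*dynamicBuffers=*/false, folder) can allocate statically:
  %buf = alloc() : memref<16xi8>   // 4 elements x 4 bytes each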
@@ -80,24 +105,28 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
  auto viewType = subView.getType();
  auto rank = viewType.getRank();
  Value allocSize = one;
  SmallVector<Value, 8> fullRanges, partialRanges;
  fullRanges.reserve(rank);
  partialRanges.reserve(rank);
  SmallVector<Value, 8> fullSizes, partialSizes;
  fullSizes.reserve(rank);
  partialSizes.reserve(rank);
  for (auto en : llvm::enumerate(subView.getRanges())) {
    auto rank = en.index();
    auto rangeValue = en.value();
    Value d = rangeValue.size;
    allocSize = folded_std_muli(folder, allocSize, d).getValue();
    fullRanges.push_back(d);
    partialRanges.push_back(
        folded_linalg_range(folder, zero, std_dim(subView, rank), one));
    // Try to extract a tight constant
    Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
    allocSize = folded_std_muli(folder, allocSize, size).getValue();
    fullSizes.push_back(size);
    partialSizes.push_back(folded_std_dim(folder, subView, rank));
  }
  SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
  SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
  auto buffer =
      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
  auto fullLocalView = std_view(
      MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
  auto partialLocalView = linalg_slice(fullLocalView, partialRanges);
      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers, folder);
  auto fullLocalView = folded_std_view(
      folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
      fullSizes);
  SmallVector<Value, 4> zeros(fullSizes.size(), zero);
  SmallVector<Value, 4> ones(fullSizes.size(), one);
  auto partialLocalView =
      folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
  return PromotionInfo{buffer, fullLocalView, partialLocalView};
}

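Read together with the test updates below, promoteFullTileBuffer now materializes each promoted operand roughly as follows (a sketch under assumed 2-D f32 shapes; %s0/%s1 stand for the possibly constant-bounded full sizes and %d0/%d1 for the dims of the original subview):

  %buf     = alloc() : memref<32xi8>
  %full    = std.view %buf[][%s0, %s1] : memref<32xi8> to memref<?x?xf32>
  %partial = subview %full[%c0, %c0][%d0, %d1][%c1, %c1]
      : memref<?x?xf32> to memref<?x?xf32, #strided2D_dynamic>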
@@ -7,7 +7,6 @@
#map3 = affine_map<(d0) -> (d0 + 3)>

// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[strided2DnoOffset:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
// CHECK-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>

func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
@@ -46,28 +45,28 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[tmpA:.*]] = alloc() : memref<32xi8>
// CHECK: %[[fullA:.*]] = std.view %[[tmpA]][][{{.*}}] : memref<32xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: %[[partialA:.*]] = subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpB:.*]] = alloc() : memref<48xi8>
// CHECK: %[[fullB:.*]] = std.view %[[tmpB]][][{{.*}}] : memref<48xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: %[[partialB:.*]] = subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpC:.*]] = alloc() : memref<24xi8>
// CHECK: %[[fullC:.*]] = std.view %[[tmpC]][][{{.*}}] : memref<24xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: %[[partialC:.*]] = subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>

// CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
// CHECK: linalg.fill(%[[fullB]], {{.*}}) : memref<?x?xf32>, f32
// CHECK: linalg.fill(%[[fullC]], {{.*}}) : memref<?x?xf32>, f32
// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
//
// CHECK: linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
//
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
//
// CHECK: dealloc %[[tmpA]] : memref<32xi8>
// CHECK: dealloc %[[tmpB]] : memref<48xi8>

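One consequence visible in these updated CHECK lines: a subview of a dynamically shaped memref conservatively carries a fully dynamic layout (symbolic offset and strides), so the partial views and the copies through them are now matched against #[[strided2D_dynamic]] rather than #[[strided2DnoOffset]]. Schematically, types only (operands elided, names illustrative):

  // Old partial view layout from linalg.slice: no offset symbol.
  memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
  // New partial view layout from subview: dynamic offset and strides.
  memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>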
@@ -111,28 +110,28 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[tmpA_f64:.*]] = alloc() : memref<64xi8>
// CHECK: %[[fullA_f64:.*]] = std.view %[[tmpA_f64]][][{{.*}}] : memref<64xi8> to memref<?x?xf64>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialA_f64:.*]] = linalg.slice %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: %[[partialA_f64:.*]] = subview %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpB_f64:.*]] = alloc() : memref<96xi8>
// CHECK: %[[fullB_f64:.*]] = std.view %[[tmpB_f64]][][{{.*}}] : memref<96xi8> to memref<?x?xf64>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialB_f64:.*]] = linalg.slice %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: %[[partialB_f64:.*]] = subview %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpC_f64:.*]] = alloc() : memref<48xi8>
// CHECK: %[[fullC_f64:.*]] = std.view %[[tmpC_f64]][][{{.*}}] : memref<48xi8> to memref<?x?xf64>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialC_f64:.*]] = linalg.slice %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: %[[partialC_f64:.*]] = subview %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>

// CHECK: linalg.fill(%[[fullA_f64]], {{.*}}) : memref<?x?xf64>, f64
// CHECK: linalg.fill(%[[fullB_f64]], {{.*}}) : memref<?x?xf64>, f64
// CHECK: linalg.fill(%[[fullC_f64]], {{.*}}) : memref<?x?xf64>, f64
// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
//
// CHECK: linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref<?x?xf64>, memref<?x?xf64>, memref<?x?xf64>
//
// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2DnoOffset]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
//
// CHECK: dealloc %[[tmpA_f64]] : memref<64xi8>
// CHECK: dealloc %[[tmpB_f64]] : memref<96xi8>

@@ -176,28 +175,28 @@ func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[tmpA_i32:.*]] = alloc() : memref<32xi8>
// CHECK: %[[fullA_i32:.*]] = std.view %[[tmpA_i32]][][{{.*}}] : memref<32xi8> to memref<?x?xi32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
// CHECK: %[[partialA_i32:.*]] = linalg.slice %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: %[[partialA_i32:.*]] = subview %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpB_i32:.*]] = alloc() : memref<48xi8>
// CHECK: %[[fullB_i32:.*]] = std.view %[[tmpB_i32]][][{{.*}}] : memref<48xi8> to memref<?x?xi32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
// CHECK: %[[partialB_i32:.*]] = linalg.slice %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: %[[partialB_i32:.*]] = subview %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpC_i32:.*]] = alloc() : memref<24xi8>
// CHECK: %[[fullC_i32:.*]] = std.view %[[tmpC_i32]][][{{.*}}] : memref<24xi8> to memref<?x?xi32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
// CHECK: %[[partialC_i32:.*]] = linalg.slice %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>

// CHECK: linalg.fill(%[[fullA_i32]], {{.*}}) : memref<?x?xi32>, i32
// CHECK: linalg.fill(%[[fullB_i32]], {{.*}}) : memref<?x?xi32>, i32
// CHECK: linalg.fill(%[[fullC_i32]], {{.*}}) : memref<?x?xi32>, i32
// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
//
// CHECK: linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
//
// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2DnoOffset]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
//
// CHECK: dealloc %[[tmpA_i32]] : memref<32xi8>
// CHECK: dealloc %[[tmpB_i32]] : memref<48xi8>