[mlir][Linalg] Use subview instead of linalg.slice in Promotion.cpp

This revision removes Promotion's reliance on `linalg.slice`, which is meant for
the rank-reducing case; the promoted partial view is now created with a standard
`subview` instead.

Differential Revision: https://reviews.llvm.org/D77676
Nicolas Vasilache 2020-04-07 23:49:08 -04:00
parent 530377018f
commit 3cb1f35df2
2 changed files with 65 additions and 37 deletions
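In a nutshell, the partial view that promotion copies into and out of is rank-preserving, so a strided subview expresses it directly. A schematic sketch with hypothetical SSA names and abbreviated subview print syntax (the layout map names are the ones used in the test file below): previously the partial view was carved out of the full local view with linalg.range / linalg.slice,

  %r0 = linalg.range %c0:%d0:%c1 : !linalg.range
  %r1 = linalg.range %c0:%d1:%c1 : !linalg.range
  %partial = linalg.slice %full[%r0, %r1]
      : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #strided2DnoOffset>

whereas it is now a standard subview with zero offsets, the partial sizes, and unit strides:

  %partial = subview %full[%c0, %c0][%d0, %d1][%c1, %c1]
      : memref<?x?xf32> to memref<?x?xf32, #strided2D_dynamic>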


@@ -39,17 +39,42 @@ using llvm::SetVector;
using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
using folded_std_dim = folded::ValueBuilder<DimOp>;
using folded_std_subview = folded::ValueBuilder<SubViewOp>;
using folded_std_view = folded::ValueBuilder<ViewOp>;

#define DEBUG_TYPE "linalg-promotion"

static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
/// If `size` comes from an AffineMinOp and one of the dimensions of AffineMin
/// is a constant then return a new value set to the smallest such constant.
/// Otherwise return size.
static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
                                                 Value size) {
  auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
  if (!affineMinOp)
    return size;
  if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) {
        return e.dyn_cast<AffineConstantExpr>();
      }))
    return size;
  int64_t minConst = std::numeric_limits<int64_t>::max();
  for (auto e : affineMinOp.getAffineMap().getResults())
    if (auto cst = e.dyn_cast<AffineConstantExpr>())
      minConst = std::min(minConst, cst.getValue());
  assert(minConst != std::numeric_limits<int64_t>::max());
  return b.create<ConstantIndexOp>(loc, minConst);
}

static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
                         OperationFolder *folder) {
  auto *ctx = size.getContext();
  auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
  if (!dynamicBuffers)
    if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
      return std_alloc(
          MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
  Value mul = std_muli(std_constant_index(width), size);
  Value mul =
      folded_std_muli(folder, folded_std_constant_index(folder, width), size);
  return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
}
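Taken together, the two functions above let a boundary-clamped tile size (an affine.min) be bounded by its smallest constant result, so that the promotion buffer can still be allocated statically. A sketch with made-up values, mirroring the memref<32xi8> allocation the CHECK lines below expect for the 2x4 f32 tile of A:

  %sz0 = affine.min affine_map<(d0)[s0] -> (2, -d0 + s0)>(%i)[%M]
  %sz1 = affine.min affine_map<(d0)[s0] -> (4, -d0 + s0)>(%k)[%K]

extractSmallestConstantBoundingSize rewrites %sz0 and %sz1 to constants 2 and 4 of index type; their folded product is a ConstantIndexOp, so allocBuffer takes the static path and emits

  %tmpA = alloc() : memref<32xi8>   // 8 f32 elements * 4 bytes each

In the dynamic-buffers mode (the DYNAMIC prefix below), the size stays dynamic and the byte count is instead computed with a muli feeding alloc(%bytes) : memref<?xi8>.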
@@ -80,24 +105,28 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
  auto viewType = subView.getType();
  auto rank = viewType.getRank();
  Value allocSize = one;
  SmallVector<Value, 8> fullRanges, partialRanges;
  fullRanges.reserve(rank);
  partialRanges.reserve(rank);
  SmallVector<Value, 8> fullSizes, partialSizes;
  fullSizes.reserve(rank);
  partialSizes.reserve(rank);
  for (auto en : llvm::enumerate(subView.getRanges())) {
    auto rank = en.index();
    auto rangeValue = en.value();
    Value d = rangeValue.size;
    allocSize = folded_std_muli(folder, allocSize, d).getValue();
    fullRanges.push_back(d);
    partialRanges.push_back(
        folded_linalg_range(folder, zero, std_dim(subView, rank), one));
    // Try to extract a tight constant
    Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
    allocSize = folded_std_muli(folder, allocSize, size).getValue();
    fullSizes.push_back(size);
    partialSizes.push_back(folded_std_dim(folder, subView, rank));
  }
  SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
  SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
  auto buffer =
      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
  auto fullLocalView = std_view(
      MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
  auto partialLocalView = linalg_slice(fullLocalView, partialRanges);
      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers, folder);
  auto fullLocalView = folded_std_view(
      folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
      fullSizes);
  SmallVector<Value, 4> zeros(fullSizes.size(), zero);
  SmallVector<Value, 4> ones(fullSizes.size(), one);
  auto partialLocalView =
      folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
  return PromotionInfo{buffer, fullLocalView, partialLocalView};
}
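With this, the promotion sequence for one operand has the following shape, sketched with hypothetical names and abbreviated subview syntax (the updated CHECK lines below match this pattern with the operands elided):

  %buf  = alloc() : memref<32xi8>
  %full = std.view %buf[][%sz0, %sz1] : memref<32xi8> to memref<?x?xf32>
  %d0   = dim %sv, 0 : memref<?x?xf32, #strided2D_dynamic>
  %d1   = dim %sv, 1 : memref<?x?xf32, #strided2D_dynamic>
  %part = subview %full[%c0, %c0][%d0, %d1][%c1, %c1]
      : memref<?x?xf32> to memref<?x?xf32, #strided2D_dynamic>

where %sv is the original subview being promoted. As before, linalg.copy(%sv, %part) fills the promoted buffer and, for the output operand, linalg.copy(%part, %sv) copies the result back; only the construction of %part changes from linalg.slice to subview.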


@@ -7,7 +7,6 @@
#map3 = affine_map<(d0) -> (d0 + 3)>
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[strided2DnoOffset:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
// CHECK-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
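(A note on the layout maps checked above, since the partial-view types change throughout this test: #strided2D_dynamic carries a dynamic offset s0 and two dynamic strides s1, s2, so with s0 = 2, s1 = 16, s2 = 1 it places element (d0, d1) at element offset d0 * 16 + 2 + d1. The subview-produced partial views below carry this layout, whereas the removed linalg.slice form produced the offset-free #strided2DnoOffset layout.)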
@@ -46,28 +45,28 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[tmpA:.*]] = alloc() : memref<32xi8>
// CHECK: %[[fullA:.*]] = std.view %[[tmpA]][][{{.*}}] : memref<32xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: %[[partialA:.*]] = subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpB:.*]] = alloc() : memref<48xi8>
// CHECK: %[[fullB:.*]] = std.view %[[tmpB]][][{{.*}}] : memref<48xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: %[[partialB:.*]] = subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpC:.*]] = alloc() : memref<24xi8>
// CHECK: %[[fullC:.*]] = std.view %[[tmpC]][][{{.*}}] : memref<24xi8> to memref<?x?xf32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: %[[partialC:.*]] = subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
// CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
// CHECK: linalg.fill(%[[fullB]], {{.*}}) : memref<?x?xf32>, f32
// CHECK: linalg.fill(%[[fullC]], {{.*}}) : memref<?x?xf32>, f32
// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
//
// CHECK: linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
//
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
//
// CHECK: dealloc %[[tmpA]] : memref<32xi8>
// CHECK: dealloc %[[tmpB]] : memref<48xi8>
@@ -111,28 +110,28 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[tmpA_f64:.*]] = alloc() : memref<64xi8>
// CHECK: %[[fullA_f64:.*]] = std.view %[[tmpA_f64]][][{{.*}}] : memref<64xi8> to memref<?x?xf64>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialA_f64:.*]] = linalg.slice %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: %[[partialA_f64:.*]] = subview %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpB_f64:.*]] = alloc() : memref<96xi8>
// CHECK: %[[fullB_f64:.*]] = std.view %[[tmpB_f64]][][{{.*}}] : memref<96xi8> to memref<?x?xf64>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialB_f64:.*]] = linalg.slice %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: %[[partialB_f64:.*]] = subview %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpC_f64:.*]] = alloc() : memref<48xi8>
// CHECK: %[[fullC_f64:.*]] = std.view %[[tmpC_f64]][][{{.*}}] : memref<48xi8> to memref<?x?xf64>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
// CHECK: %[[partialC_f64:.*]] = linalg.slice %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: %[[partialC_f64:.*]] = subview %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
// CHECK: linalg.fill(%[[fullA_f64]], {{.*}}) : memref<?x?xf64>, f64
// CHECK: linalg.fill(%[[fullB_f64]], {{.*}}) : memref<?x?xf64>, f64
// CHECK: linalg.fill(%[[fullC_f64]], {{.*}}) : memref<?x?xf64>, f64
// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
//
// CHECK: linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref<?x?xf64>, memref<?x?xf64>, memref<?x?xf64>
//
// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2DnoOffset]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
//
// CHECK: dealloc %[[tmpA_f64]] : memref<64xi8>
// CHECK: dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -176,28 +175,28 @@ func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: %[[tmpA_i32:.*]] = alloc() : memref<32xi8>
// CHECK: %[[fullA_i32:.*]] = std.view %[[tmpA_i32]][][{{.*}}] : memref<32xi8> to memref<?x?xi32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
// CHECK: %[[partialA_i32:.*]] = linalg.slice %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: %[[partialA_i32:.*]] = subview %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpB_i32:.*]] = alloc() : memref<48xi8>
// CHECK: %[[fullB_i32:.*]] = std.view %[[tmpB_i32]][][{{.*}}] : memref<48xi8> to memref<?x?xi32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
// CHECK: %[[partialB_i32:.*]] = linalg.slice %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: %[[partialB_i32:.*]] = subview %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
///
// CHECK: %[[tmpC_i32:.*]] = alloc() : memref<24xi8>
// CHECK: %[[fullC_i32:.*]] = std.view %[[tmpC_i32]][][{{.*}}] : memref<24xi8> to memref<?x?xi32>
// DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
// CHECK: %[[partialC_i32:.*]] = linalg.slice %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
// CHECK: linalg.fill(%[[fullA_i32]], {{.*}}) : memref<?x?xi32>, i32
// CHECK: linalg.fill(%[[fullB_i32]], {{.*}}) : memref<?x?xi32>, i32
// CHECK: linalg.fill(%[[fullC_i32]], {{.*}}) : memref<?x?xi32>, i32
// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
//
// CHECK: linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
//
// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2DnoOffset]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
//
// CHECK: dealloc %[[tmpA_i32]] : memref<32xi8>
// CHECK: dealloc %[[tmpB_i32]] : memref<48xi8>