[mlir][Linalg] NFC - Cleanup conv1d generators

Differential Revision: https://reviews.llvm.org/D117330
This commit is contained in:
Nicolas Vasilache 2022-01-14 17:18:37 +00:00
parent c7ca4c6365
commit 392e16c27f
2 changed files with 79 additions and 71 deletions

View File

@ -1287,6 +1287,22 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
//===----------------------------------------------------------------------===//
// Convolution vectorization patterns
//===----------------------------------------------------------------------===//
template <int N>
static void bindShapeDims(ShapedType shapedType) {}
template <int N, typename IntTy, typename... IntTy2>
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
val = shapedType.getShape()[N];
bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
}
/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
template <typename... IntTy>
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
bindShapeDims<0>(shapedType, vals...);
}
namespace {
/// Generate a vector implementation for either:
/// ```
@ -1354,11 +1370,11 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
if (!valid)
return failure();
int nSize = lhsShapedType.getShape()[0];
int wSize = resShapedType.getShape()[1];
int cSize = lhsShapedType.getShape()[2];
int kwSize = rhsShapedType.getShape()[0];
int fSize = rhsShapedType.getShape()[2];
int64_t nSize, wSize, cSize, kwSize, fSize;
// kernel{kw, c, f}
bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
// out{n, w, f}
bindShapeDims(resShapedType, nSize, wSize);
vector::TransferWriteOp write;
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@ -1398,31 +1414,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
//===------------------------------------------------------------------===//
// Unroll along kw and read slices of lhs and rhs.
SmallVector<Value> lhsVals, rhsVals, resVals;
// Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
for (int64_t kw = 0; kw < kwSize; ++kw) {
// Extract rhs slice of size {c, f} @ [kw].
rhsVals.push_back(builder.create<vector::ExtractOp>(
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
for (int64_t w = 0; w < wSize; w += wSizeStep) {
// Extract lhs slice of size {n, wSizeStep, c}
// @ [0, sw * w + dw * kw, 0].
lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
loc, lhs,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
// This does not depend on kw.
if (kw == 0) {
}
}
// Extract rhs slice of size {c, f} @ [kw].
for (int64_t kw = 0; kw < kwSize; ++kw) {
rhsVals.push_back(builder.create<vector::ExtractOp>(
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
}
// Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
loc, res,
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
}
}
}
auto linearIndex = [&](int64_t kw, int64_t w) {
return kw * (wSize / wSizeStep) + w;
@ -1476,14 +1490,15 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
FailureOr<Operation *> dilatedConv() {
FailureOr<Operation *> depthwiseConv() {
if (!valid)
return failure();
int nSize = lhsShapedType.getShape()[0];
int wSize = resShapedType.getShape()[1];
int cSize = lhsShapedType.getShape()[2];
int kwSize = rhsShapedType.getShape()[0];
int64_t nSize, wSize, cSize, kwSize;
// kernel{kw, c}
bindShapeDims(rhsShapedType, kwSize, cSize);
// out{n, w, c}
bindShapeDims(resShapedType, nSize, wSize);
vector::TransferWriteOp write;
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
@ -1522,31 +1537,30 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
//===------------------------------------------------------------------===//
// Unroll along kw and read slices of lhs and rhs.
SmallVector<Value> lhsVals, rhsVals, resVals;
for (int64_t kw = 0; kw < kwSize; ++kw) {
// Extract rhs slice of size {c} @ [kw].
rhsVals.push_back(builder.create<vector::ExtractOp>(
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
for (int64_t w = 0; w < wSize; w += wSizeStep) {
// Extract lhs slice of size {n, wSizeStep, c}
// @ [0, sw * w + dw * kw, 0].
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
loc, lhs,
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
// This does not depend on kw.
if (kw == 0) {
}
}
// Extract rhs slice of size {c} @ [kw].
for (int64_t kw = 0; kw < kwSize; ++kw) {
rhsVals.push_back(builder.create<vector::ExtractOp>(
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
}
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
loc, res,
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
}
}
}
auto linearIndex = [&](int64_t kw, int64_t w) {
return kw * (wSize / wSizeStep) + w;
@ -1555,7 +1569,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
for (int64_t kw = 0; kw < kwSize; ++kw) {
for (int64_t w = 0; w < wSize; w += wSizeStep) {
resVals[w] = dilatedConv1dSliceAsFma(
resVals[w] = depthwiseConv1dSliceAsFma(
builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
}
}
@ -1580,7 +1594,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
}
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
Value rhs, Value res) {
Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
return b.create<vector::FMAOp>(loc, lhs, bcast, res);
@ -1614,7 +1628,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
return dilatedConv();
return depthwiseConv();
return failure();
}

View File

@ -23,15 +23,15 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
/// w == 0, kw == 0
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
/// w == 1, kw == 0
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
@ -84,27 +84,23 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
/// w == 0, kw == 0
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
/// w == 1, kw == 0
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
/// w == 0, kw == 1
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
/// w == 1, kw == 0
// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
/// w == 0, kw == 0
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
@ -165,15 +161,14 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
/// w == 0, kw == 0
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
/// w == 0, kw == 1
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
/// w == 0, kw == 0
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
@ -211,15 +206,14 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
/// w == 0, kw == 0
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
/// w == 0, kw == 1
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
/// w == 0, kw = 0
// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
// CHECK: %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>