forked from OSchip/llvm-project
[mlir][Linalg] NFC - Cleanup conv1d generators
Differential Revision: https://reviews.llvm.org/D117330
This commit is contained in:
parent
c7ca4c6365
commit
392e16c27f
|
@ -1287,6 +1287,22 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Convolution vectorization patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <int N>
|
||||
static void bindShapeDims(ShapedType shapedType) {}
|
||||
|
||||
template <int N, typename IntTy, typename... IntTy2>
|
||||
static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) {
|
||||
val = shapedType.getShape()[N];
|
||||
bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...);
|
||||
}
|
||||
|
||||
/// Bind a pack of int& to the leading dimensions of shapedType.getShape().
|
||||
template <typename... IntTy>
|
||||
static void bindShapeDims(ShapedType shapedType, IntTy &...vals) {
|
||||
bindShapeDims<0>(shapedType, vals...);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Generate a vector implementation for either:
|
||||
/// ```
|
||||
|
@ -1354,11 +1370,11 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
if (!valid)
|
||||
return failure();
|
||||
|
||||
int nSize = lhsShapedType.getShape()[0];
|
||||
int wSize = resShapedType.getShape()[1];
|
||||
int cSize = lhsShapedType.getShape()[2];
|
||||
int kwSize = rhsShapedType.getShape()[0];
|
||||
int fSize = rhsShapedType.getShape()[2];
|
||||
int64_t nSize, wSize, cSize, kwSize, fSize;
|
||||
// kernel{kw, c, f}
|
||||
bindShapeDims(rhsShapedType, kwSize, cSize, fSize);
|
||||
// out{n, w, f}
|
||||
bindShapeDims(resShapedType, nSize, wSize);
|
||||
|
||||
vector::TransferWriteOp write;
|
||||
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
||||
|
@ -1398,31 +1414,29 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
//===------------------------------------------------------------------===//
|
||||
// Unroll along kw and read slices of lhs and rhs.
|
||||
SmallVector<Value> lhsVals, rhsVals, resVals;
|
||||
// Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0].
|
||||
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
||||
// Extract rhs slice of size {c, f} @ [kw].
|
||||
rhsVals.push_back(builder.create<vector::ExtractOp>(
|
||||
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
|
||||
|
||||
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
||||
// Extract lhs slice of size {n, wSizeStep, c}
|
||||
// @ [0, sw * w + dw * kw, 0].
|
||||
lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
|
||||
loc, lhs,
|
||||
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
|
||||
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
|
||||
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
|
||||
|
||||
// This does not depend on kw.
|
||||
if (kw == 0) {
|
||||
}
|
||||
}
|
||||
// Extract rhs slice of size {c, f} @ [kw].
|
||||
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
||||
rhsVals.push_back(builder.create<vector::ExtractOp>(
|
||||
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
|
||||
}
|
||||
// Extract res slice: {n, wSizeStep, f} @ [0, w, 0].
|
||||
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
||||
resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
|
||||
loc, res,
|
||||
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
|
||||
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, fSize},
|
||||
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto linearIndex = [&](int64_t kw, int64_t w) {
|
||||
return kw * (wSize / wSizeStep) + w;
|
||||
|
@ -1476,14 +1490,15 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
/// kw is always unrolled.
|
||||
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
|
||||
/// > 1.
|
||||
FailureOr<Operation *> dilatedConv() {
|
||||
FailureOr<Operation *> depthwiseConv() {
|
||||
if (!valid)
|
||||
return failure();
|
||||
|
||||
int nSize = lhsShapedType.getShape()[0];
|
||||
int wSize = resShapedType.getShape()[1];
|
||||
int cSize = lhsShapedType.getShape()[2];
|
||||
int kwSize = rhsShapedType.getShape()[0];
|
||||
int64_t nSize, wSize, cSize, kwSize;
|
||||
// kernel{kw, c}
|
||||
bindShapeDims(rhsShapedType, kwSize, cSize);
|
||||
// out{n, w, c}
|
||||
bindShapeDims(resShapedType, nSize, wSize);
|
||||
|
||||
vector::TransferWriteOp write;
|
||||
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
|
||||
|
@ -1522,31 +1537,30 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
//===------------------------------------------------------------------===//
|
||||
// Unroll along kw and read slices of lhs and rhs.
|
||||
SmallVector<Value> lhsVals, rhsVals, resVals;
|
||||
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
||||
// Extract rhs slice of size {c} @ [kw].
|
||||
rhsVals.push_back(builder.create<vector::ExtractOp>(
|
||||
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
|
||||
|
||||
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
||||
// Extract lhs slice of size {n, wSizeStep, c}
|
||||
// @ [0, sw * w + dw * kw, 0].
|
||||
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
||||
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
||||
lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
|
||||
loc, lhs,
|
||||
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
|
||||
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
|
||||
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
|
||||
|
||||
// This does not depend on kw.
|
||||
if (kw == 0) {
|
||||
}
|
||||
}
|
||||
// Extract rhs slice of size {c} @ [kw].
|
||||
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
||||
rhsVals.push_back(builder.create<vector::ExtractOp>(
|
||||
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
|
||||
}
|
||||
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
|
||||
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
||||
resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
|
||||
loc, res,
|
||||
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
|
||||
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
|
||||
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto linearIndex = [&](int64_t kw, int64_t w) {
|
||||
return kw * (wSize / wSizeStep) + w;
|
||||
|
@ -1555,7 +1569,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
|
||||
for (int64_t kw = 0; kw < kwSize; ++kw) {
|
||||
for (int64_t w = 0; w < wSize; w += wSizeStep) {
|
||||
resVals[w] = dilatedConv1dSliceAsFma(
|
||||
resVals[w] = depthwiseConv1dSliceAsFma(
|
||||
builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
|
||||
}
|
||||
}
|
||||
|
@ -1580,7 +1594,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
}
|
||||
|
||||
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
|
||||
Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
|
||||
Value depthwiseConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
|
||||
Value rhs, Value res) {
|
||||
Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
|
||||
return b.create<vector::FMAOp>(loc, lhs, bcast, res);
|
||||
|
@ -1614,7 +1628,7 @@ struct Conv1DNwcGenerator : public StructuredGenerator<LinalgOp> {
|
|||
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
|
||||
/*rhsIndex*/ {kw, c},
|
||||
/*resIndex*/ {n, w, c}}))
|
||||
return dilatedConv();
|
||||
return depthwiseConv();
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
|
|
@ -23,15 +23,15 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x3x8xf
|
|||
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
|
||||
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
|
||||
|
||||
// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
|
||||
/// w == 0, kw == 0
|
||||
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
|
||||
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
|
||||
/// w == 1, kw == 0
|
||||
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x1x3xf32>
|
||||
|
||||
// CHECK: %[[V_FILTER:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<1x3x8xf32>
|
||||
|
||||
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
|
||||
// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
|
||||
|
||||
|
@ -84,27 +84,23 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
|
|||
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
|
||||
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
|
||||
|
||||
|
||||
/// w == 0, kw == 0
|
||||
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
|
||||
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
|
||||
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
|
||||
/// w == 1, kw == 0
|
||||
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 3, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
|
||||
// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
|
||||
|
||||
/// w == 0, kw == 1
|
||||
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
|
||||
// CHECK: %[[V_INPUT_2:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
|
||||
/// w == 1, kw == 0
|
||||
// CHECK: %[[V_INPUT_3:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 5, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x6x3xf32> to vector<4x1x3xf32>
|
||||
|
||||
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
|
||||
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
|
||||
|
||||
// CHECK: %[[V_OUTPUT_0:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
|
||||
// CHECK: %[[V_OUTPUT_1:.+]] = vector.extract_strided_slice %[[V_OUTPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 1, 0], sizes = [4, 1, 8], strides = [1, 1, 1]} : vector<4x2x8xf32> to vector<4x1x8xf32>
|
||||
|
||||
/// w == 0, kw == 0
|
||||
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
|
||||
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
|
||||
|
@ -165,15 +161,14 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
|
|||
// CHECK-DAG: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
|
||||
// CHECK-DAG: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]], %[[F0]]
|
||||
|
||||
/// w == 0, kw == 0
|
||||
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
|
||||
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
|
||||
/// w == 0, kw == 1
|
||||
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
|
||||
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [4, 2, 3], strides = [1, 1, 1]} : vector<4x4x3xf32> to vector<4x2x3xf32>
|
||||
|
||||
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x3x8xf32>
|
||||
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x3x8xf32>
|
||||
|
||||
/// w == 0, kw == 0
|
||||
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
|
||||
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
|
||||
|
@ -211,15 +206,14 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
|
|||
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
|
||||
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
|
||||
|
||||
/// w == 0, kw == 0
|
||||
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
|
||||
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
|
||||
/// w == 0, kw == 1
|
||||
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
|
||||
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
|
||||
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
|
||||
|
||||
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
|
||||
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
|
||||
|
||||
/// w == 0, kw = 0
|
||||
// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
|
||||
// CHECK: %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>
|
||||
|
|
Loading…
Reference in New Issue