[mlir][Linalg] Reduction dimensions specified in TC definition of ConvOps.

This commit specifies reduction dimensions for ConvOps. This prevents
running reduction loops in parallel and enables easier detection of kernel dimensions
which we will need later on.

Differential Revision: https://reviews.llvm.org/D87288
This commit is contained in:
Jakub Lichman 2020-09-08 15:04:35 +00:00
parent e706116e11
commit 53ffeea6d5
2 changed files with 47 additions and 43 deletions

View File

@ -20,52 +20,50 @@ def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M,
ods_def<ConvWOp>: ods_def<ConvWOp>:
def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) { def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw))); O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));
} }
ods_def<ConvNWCOp>: ods_def<ConvNWCOp>:
def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) { def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
O(n, w, f) = std_addf(O(n, w, f), O(n, w, f) = std_addf<kw>(std_mulf(I(n, w + kw, c), K(f, kw, c)));
std_mulf(I(n, w + kw, c), K(f, kw, c)));
} }
ods_def<ConvNCWOp>: ods_def<ConvNCWOp>:
def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) { def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
O(n, f, w) = std_addf(O(n, f, w), O(n, f, w) = std_addf<kw>(std_mulf(I(n, c, w + kw), K(f, c, kw)));
std_mulf(I(n, c, w + kw), K(f, c, kw)));
} }
ods_def<ConvHWOp>: ods_def<ConvHWOp>:
def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) { def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw))); O(h, w) = std_addf<kh, kw>(std_mulf(I(h + kh, w + kw), K(kh, kw)));
} }
ods_def<ConvNHWCOp>: ods_def<ConvNHWCOp>:
def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) { def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
O(n, h, w, f) = std_addf(O(n, h, w, f), O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c))); I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
} }
ods_def<ConvNCHWOp>: ods_def<ConvNCHWOp>:
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) { def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
O(n, f, h, w) = std_addf(O(n, f, h, w), O(n, f, h, w) = std_addf<kh, kw>(std_mulf(
std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw))); I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
} }
ods_def<ConvDHWOp>: ods_def<ConvDHWOp>:
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) { def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
O(d, h, w) = std_addf(O(d, h, w), O(d, h, w) = std_addf<kd, kh, kw>(std_mulf(
std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw))); I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
} }
ods_def<ConvNDHWCOp>: ods_def<ConvNDHWCOp>:
def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) { def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
O(n, d, h, w, f) = std_addf(O(n, d, h, w, f), O(n, d, h, w, f) = std_addf<kd, kh, kw>(std_mulf(
std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c))); I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
} }
ods_def<ConvNCDHWOp>: ods_def<ConvNCDHWOp>:
def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) { def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
O(n, f, d, h, w) = std_addf(O(n, f, d, h, w), O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
} }

View File

@ -1318,7 +1318,8 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index // CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32> // CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32> // CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) { // CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
// CHECKPARALLEL: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]]) // CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32> // CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32> // CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
@ -1367,7 +1368,9 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32> // CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32> // CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32> // CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) { // CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
// CHECKPARALLEL: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]]) // CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]]) // CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32> // CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
@ -1427,7 +1430,10 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32> // CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32> // CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32> // CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) { // CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECKPARALLEL: scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
// CHECKPARALLEL: scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]]) // CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]]) // CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]]) // CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])