forked from OSchip/llvm-project
[mlir][Linalg] Reduction dimensions specified in TC definition of ConvOps.
This commit specifies reduction dimensions for ConvOps. This prevents running reduction loops in parallel and enables easier detection of kernel dimensions which we will need later on. Differential Revision: https://reviews.llvm.org/D87288
This commit is contained in:
parent
e706116e11
commit
53ffeea6d5
|
@ -20,52 +20,50 @@ def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M,
|
|||
|
||||
ods_def<ConvWOp>:
|
||||
def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
|
||||
O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw)));
|
||||
O(w) = std_addf<kw>(std_mulf(I(w + kw), K(kw)));
|
||||
}
|
||||
|
||||
ods_def<ConvNWCOp>:
|
||||
def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
|
||||
O(n, w, f) = std_addf(O(n, w, f),
|
||||
std_mulf(I(n, w + kw, c), K(f, kw, c)));
|
||||
O(n, w, f) = std_addf<kw>(std_mulf(I(n, w + kw, c), K(f, kw, c)));
|
||||
}
|
||||
|
||||
ods_def<ConvNCWOp>:
|
||||
def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
|
||||
O(n, f, w) = std_addf(O(n, f, w),
|
||||
std_mulf(I(n, c, w + kw), K(f, c, kw)));
|
||||
O(n, f, w) = std_addf<kw>(std_mulf(I(n, c, w + kw), K(f, c, kw)));
|
||||
}
|
||||
|
||||
ods_def<ConvHWOp>:
|
||||
def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
|
||||
O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
|
||||
O(h, w) = std_addf<kh, kw>(std_mulf(I(h + kh, w + kw), K(kh, kw)));
|
||||
}
|
||||
|
||||
ods_def<ConvNHWCOp>:
|
||||
def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
|
||||
O(n, h, w, f) = std_addf(O(n, h, w, f),
|
||||
std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
|
||||
O(n, h, w, f) = std_addf<kh, kw>(std_mulf(
|
||||
I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
|
||||
}
|
||||
|
||||
ods_def<ConvNCHWOp>:
|
||||
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
|
||||
O(n, f, h, w) = std_addf(O(n, f, h, w),
|
||||
std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
|
||||
O(n, f, h, w) = std_addf<kh, kw>(std_mulf(
|
||||
I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
|
||||
}
|
||||
|
||||
ods_def<ConvDHWOp>:
|
||||
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
|
||||
O(d, h, w) = std_addf(O(d, h, w),
|
||||
std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
|
||||
O(d, h, w) = std_addf<kd, kh, kw>(std_mulf(
|
||||
I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
|
||||
}
|
||||
|
||||
ods_def<ConvNDHWCOp>:
|
||||
def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
|
||||
O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
|
||||
std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
|
||||
O(n, d, h, w, f) = std_addf<kd, kh, kw>(std_mulf(
|
||||
I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
|
||||
}
|
||||
|
||||
ods_def<ConvNCDHWOp>:
|
||||
def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
|
||||
O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
|
||||
std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
|
||||
O(n, f, d, h, w) = std_addf<kd, kh, kw>(std_mulf(
|
||||
I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
|
||||
}
|
|
@ -1318,14 +1318,15 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
|
|||
// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
|
||||
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) {
|
||||
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
|
||||
// CHECKPARALLEL: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
|
||||
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
|
||||
|
||||
|
||||
func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
|
||||
|
@ -1367,15 +1368,17 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
|
|||
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
|
||||
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
|
||||
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
|
||||
// CHECKPARALLEL: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
|
||||
// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
|
||||
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
|
||||
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
|
||||
|
||||
|
||||
func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
|
||||
|
@ -1427,13 +1430,16 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
|
|||
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
|
||||
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
|
||||
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
|
||||
// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
|
||||
// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
|
||||
// CHECKPARALLEL: scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
|
||||
// CHECKPARALLEL: scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
|
||||
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
|
||||
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
|
||||
// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
|
||||
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
|
||||
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
|
||||
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
|
||||
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue