[mlir][Linalg] Conv {1,2,3}D ops defined with TC syntax

Replaced definition of named ND ConvOps with tensor comprehension
syntax which reduces boilerplate code significantly. Furthermore,
new ops to support TF convolutions added (without strides and dilations).

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D84628
This commit is contained in:
Jakub Lichman 2020-07-31 13:18:11 +02:00 committed by Alex Zinenko
parent 03116a9f8c
commit eef1bfb2d2
8 changed files with 146 additions and 285 deletions

View File

@ -17,3 +17,55 @@ ods_def<BatchMatmulOp>:
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
}
ods_def<ConvWOp>:
def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw)));
}
ods_def<ConvNWCOp>:
def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
O(n, w, f) = std_addf(O(n, w, f),
std_mulf(I(n, w + kw, c), K(f, kw, c)));
}
ods_def<ConvNCWOp>:
def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
O(n, f, w) = std_addf(O(n, f, w),
std_mulf(I(n, c, w + kw), K(f, c, kw)));
}
ods_def<ConvHWOp>:
def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) {
O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw)));
}
ods_def<ConvNHWCOp>:
def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
O(n, h, w, f) = std_addf(O(n, h, w, f),
std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
}
ods_def<ConvNCHWOp>:
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
O(n, f, h, w) = std_addf(O(n, f, h, w),
std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
}
ods_def<ConvDHWOp>:
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
O(d, h, w) = std_addf(O(d, h, w),
std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw)));
}
ods_def<ConvNDHWCOp>:
def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
}
ods_def<ConvNCDHWOp>:
def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
}

View File

@ -85,14 +85,6 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b);
/// Generates indexing maps for convolution with the following structure:
/// input: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
/// kernel: (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
/// output: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
/// where r is the rank of the input, kernel and output
llvm::Optional<SmallVector<AffineMap, 8>>
createConvNDIndexingMaps(MLIRContext *context, unsigned rank);
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.h.inc"
#define GET_OP_CLASSES

View File

@ -180,131 +180,6 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
let hasFolder = 1;
}
class ConvOpBase<string mnemonic, int N>
: LinalgStructured_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
let description = [{
Base operation for any N-D Convolution implemented as a linalg.generic op.
Usage:
```mlir
linalg.conv<N>D(%in, %filter, %out) : memref<(?x)+f32>,
memref<(?x)+f32>,
memref<(?x)+f32>
```
where %in: input array
%filter: kernel or filter that will be applied on the input array
%out: output array
and rank of the operands is *N*.
Every child convolution is expressed as:
```mlir
#conv_trait = {
args_in = 2,
args_out = 1,
indexing_maps = #conv_accesses,
library_call = "linalg_conv",
iterator_types = [("parallel", "parallel")+], // `2 * rank` iterators
}
linalg.generic #conv_trait %in, %filter, %out {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
} : memref<(?x)+f32>,
memref<(?x)+f32>,
memref<(?x)+f32>
```
where #conv_accesses depend on the rank of the operands and thus
can be found in the documentation of each N-D case.
Please note that the input array is expected to be right-padded i.e.
the size of the input is greater than or equal to the size of the output
+ size of the kernel - 1. If it is not padded the behavior of the op
is undefined.
}];
let arguments = (ins AnyStridedMemRefOfRank<N>,
AnyStridedMemRefOfRank<N>,
AnyStridedMemRefOfRank<N>);
let extraClassDeclaration = libraryCallName # [{
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
// There are always 2 loops for each dimension of the convolution. First
// iterates output and second kernel. Since ranks of all 3 operands must
// be the same it does not matter which operand is picked to get the rank.
// Loops iterating the output can be parallelized and thus are marked as
// "parallel" while loops iterating the kernel are accumulating the
// products and therefore are marked as "reduction".
unsigned rank = getInputShapedType(0).getRank();
SmallVector<StringRef, 8> parallel(rank, getParallelIteratorTypeName());
SmallVector<StringRef, 8> reduction(rank, getReductionIteratorTypeName());
parallel.insert(parallel.end(), reduction.begin(), reduction.end());
return parallel;
}
// Generates indexing maps with the following structure:
// input: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1 + n_1, ..., m_r + n_r)
// kernel: (m_1, ..., m_r, n_1, ..., n_r) -> (n_1, ..., n_r)
// output: (m_1, ..., m_r, n_1, ..., n_r) -> (m_1, ..., m_r)
// where r is the rank of the input, kernel and output
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
MLIRContext *context = getContext();
unsigned rank = getInputShapedType(0).getRank();
return createConvNDIndexingMaps(context, rank);
}
}];
let hasFolder = 1;
let verifier = [{ return ::verify(*this); }];
}
def Conv1DOp : ConvOpBase<"conv1D", 1> {
let description = [{
*1D* convolution which uses following affine maps to access operands:
```mlir
#conv_accesses = [
affine_map<(m, n) -> (m + n)>, // in
affine_map<(m, n) -> (n)>, // kernel
affine_map<(m, n) -> (m)> // out
]
```
}];
}
def Conv2DOp : ConvOpBase<"conv2D", 2> {
let description = [{
*2D* convolution which uses following affine maps to access operands:
```mlir
#conv_accesses = [
affine_map<(m1, m2, n1, n2) -> (m1 + n1, m2 + n2)>, // in
affine_map<(m1, m2, n1, n2) -> (n1, n2)>, // kernel
affine_map<(m1, m2, n1, n2) -> (m1, m2) // out
]
```
}];
}
def Conv3DOp : ConvOpBase<"conv3D", 3> {
let description = [{
*3D* convolution which uses following affine maps to access operands:
```mlir
#conv_accesses = [
affine_map<(m1, m2, m3, n1, n2, n3) -> (m1 + n1, m2 + n2, m3 + n3)>, // in
affine_map<(m1, m2, m3, n1, n2, n3) -> (n1, n2, n3)>, // kernel
affine_map<(m1, m2, m3, n1, n2, n3) -> (m1, m2, m3)> // out
]
```
}];
}
/// A base class for pooling operation such as conv. The arguments must contain
/// optional arguments `strides`, `dilations` and `padding` with following type:
/// OptionalAttr<I64ArrayAttr>:$strides

View File

@ -236,9 +236,6 @@ void mlir::populateLinalgToStandardConversionPatterns(
LinalgOpConversion<PoolingMinOp>,
LinalgOpConversion<PoolingSumOp>,
LinalgOpConversion<CopyOp>,
LinalgOpConversion<Conv1DOp>,
LinalgOpConversion<Conv2DOp>,
LinalgOpConversion<Conv3DOp>,
LinalgOpConversion<FillOp>,
LinalgOpConversion<GenericOp>,
LinalgOpConversion<IndexedGenericOp>>(ctx);

View File

@ -986,17 +986,6 @@ static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
return success();
}
template <typename ConvNDOp>
static LogicalResult verify(ConvNDOp op) {
auto outputType = op.getOutputShapedType(0).getElementType();
auto inputType = op.getInputShapedType(0).getElementType();
auto kernelType = op.getInputShapedType(1).getElementType();
if (outputType != inputType || inputType != kernelType)
return op.emitOpError("expected all element types of operands to match");
return success();
}
static LogicalResult verify(ConvOp op) {
auto oType = op.output().getType().cast<MemRefType>();
auto fType = op.filter().getType().cast<MemRefType>();
@ -1107,27 +1096,6 @@ mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
return res;
}
llvm::Optional<SmallVector<AffineMap, 8>>
mlir::linalg::createConvNDIndexingMaps(MLIRContext *context, unsigned rank) {
unsigned numDims = rank * 2, idx = 0;
SmallVector<AffineExpr, 8> dims, in, kernel, out;
dims = makeAffineDimExprs(numDims, idx, context);
in.reserve(rank);
kernel.reserve(rank);
out.reserve(rank);
for (unsigned i = 0; i < rank; i++) {
in.push_back(dims[i] + dims[rank + i]);
kernel.push_back(dims[rank + i]);
out.push_back(dims[i]);
}
return SmallVector<AffineMap, 8>{AffineMap::get(numDims, 0, in, context),
AffineMap::get(numDims, 0, kernel, context),
AffineMap::get(numDims, 0, out, context)};
}
#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE) \
template SmallVector<AffineExpr, 4> \
mlir::linalg::weightedPoolingInputIndex<OP_TYPE>( \
@ -1209,18 +1177,6 @@ LogicalResult FillOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult Conv1DOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult Conv2DOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult Conv3DOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
@ -1362,3 +1318,39 @@ LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}

View File

@ -295,61 +295,6 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
}
/// Following functions emit scalar part of the N-D convolution op.
/// N-D convolution has 2N loops:
/// 1-N: Iterate over the output array *O* with iterators *m1, ..., mN*.
/// N-2N:. Iterate over the kernel *K* with iterators *n1, ..., nN*.
///
/// The scalar part accumulates products of input array *I* values with kernel
/// ones. The accumulation expression therefore looks like:
/// O[m1, ..., mN] += I[m1 + n1, ..., mN + nN] * K[n1, ..., nN].
/// Note that the input array has to be padded in order to prevent
/// out of bounds accesses.
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, Conv1DOp convOp) {
assert(convOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(allIvs.size() == 2);
Value m1(allIvs[0]);
Value n1(allIvs[1]);
IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
O(convOp.getOutputBuffer(0));
// Emit scalar form for the 1D conv case.
Value i1 = m1 + n1;
O(m1) = O(m1) + I(i1) * K(n1);
}
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, Conv2DOp convOp) {
assert(convOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(allIvs.size() == 4);
Value m1(allIvs[0]), m2(allIvs[1]);
Value n1(allIvs[2]), n2(allIvs[3]);
IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
O(convOp.getOutputBuffer(0));
// Emit scalar form for the 2D conv case.
Value i1 = m1 + n1;
Value i2 = m2 + n2;
O(m1, m2) = O(m1, m2) + I(i1, i2) * K(n1, n2);
}
template <typename IndexedValueType>
void emitScalarImplementation(ArrayRef<Value> allIvs, Conv3DOp convOp) {
assert(convOp.hasBufferSemantics() &&
"expected linalg op with buffer semantics");
assert(allIvs.size() == 6);
Value m1(allIvs[0]), m2(allIvs[1]), m3(allIvs[2]);
Value n1(allIvs[3]), n2(allIvs[4]), n3(allIvs[5]);
IndexedValueType I(convOp.getInput(0)), K(convOp.getInput(1)),
O(convOp.getOutputBuffer(0));
// Emit scalar form for the 3D conv case.
Value i1 = m1 + n1;
Value i2 = m2 + n2;
Value i3 = m3 + n3;
O(m1, m2, m3) = O(m1, m2, m3) + I(i1, i2, i3) * K(n1, n2, n3);
}
template <typename IndexedValueType>
Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
MutableArrayRef<Value> imIdx) {
@ -738,6 +683,24 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
if (isa<BatchMatmulOp>(op))
return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
if (isa<ConvWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
if (isa<ConvNWCOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
if (isa<ConvNCWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
if (isa<ConvHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
if (isa<ConvNHWCOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
if (isa<ConvNCHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
if (isa<ConvDHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
if (isa<ConvNDHWCOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
if (isa<ConvNCDHWOp>(op))
return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
}

View File

@ -507,11 +507,3 @@ func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?xf32>, %c3: memref<?x?x?x
linalg.batch_matmul %a3, %b3, %c3 : (memref<?x?x?xf32>, memref<?x?xf32>, memref<?x?x?xf32>) -> ()
return
}
// -----
func @conv_type_mismatch(%in: memref<?xi32>, %filter: memref<?xf32>, %out: memref<?xf32>) {
// expected-error @+1 {{expected all element types of operands to match}}
linalg.conv1D(%in, %filter, %out) : memref<?xi32>, memref<?xf32>, memref<?xf32>
return
}

View File

@ -1288,7 +1288,7 @@ func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out : m
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[i0]], %[[i1]], %[[i2]], %[[i3]]] : memref<?x?x?x?xf32>
func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
linalg.conv1D(%in, %filter, %out) : memref<?xf32>, memref<?xf32>, memref<?xf32>
linalg.conv_1d %in, %filter, %out : (memref<?xf32>, memref<?xf32>, memref<?xf32>)
return
}
@ -1303,10 +1303,10 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
// CHECKLOOP: scf.for %[[b:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
// CHECKLOOP: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKLOOP: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
@ -1318,19 +1318,18 @@ func @conv1d_no_symbols(%in : memref<?xf32>, %filter : memref<?xf32>, %out : mem
// CHECKPARALLEL: %[[c1:.*]] = constant 1 : index
// CHECKPARALLEL: %[[dim0:.*]] = dim %[[arg1]], %[[c0]] : memref<?xf32>
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg2]], %[[c0]] : memref<?xf32>
// CHECKPARALLEL: scf.parallel (%[[b:.*]]) = (%[[c0]]) to (%[[dim1]]) step (%[[c1]]) {
// CHECKPARALLEL: scf.for %[[m:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
// CHECKPARALLEL: scf.parallel (%[[b:.*]], %[[m:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim1]], %[[dim0]]) step (%[[c1]], %[[c1]]) {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[b]], %[[m]])
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]]] : memref<?xf32>
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[m]]] : memref<?xf32>
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[b]]] : memref<?xf32>
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[b]]] : memref<?xf32>
func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
linalg.conv2D(%in, %filter, %out) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
linalg.conv_2d %in, %filter, %out : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
// CHECKLOOP-LABEL: @conv2d_no_symbols
@ -1349,10 +1348,12 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
// CHECKLOOP: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKLOOP: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
@ -1366,21 +1367,19 @@ func @conv2d_no_symbols(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out :
// CHECKPARALLEL: %[[dim1:.*]] = dim %[[arg1]], %[[c1]] : memref<?x?xf32>
// CHECKPARALLEL: %[[dim2:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?xf32>
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?xf32>
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]]) = (%[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]]) step (%[[c1]], %[[c1]]) {
// CHECKPARALLEL: scf.for %[[arg5:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim2]], %[[dim3]], %[[dim0]], %[[dim1]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg5]])
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg6]])
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]]] : memref<?x?xf32>
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg5]], %[[arg6]]] : memref<?x?xf32>
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]]] : memref<?x?xf32>
func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
linalg.conv3D(%in, %filter, %out) : memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>
linalg.conv_3d %in, %filter, %out : (memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>)
return
}
@ -1406,10 +1405,12 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
// CHECKLOOP: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
// CHECKLOOP: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
// CHECKLOOP: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
// CHECKLOOP: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKLOOP: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
// CHECKLOOP: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
// CHECKLOOP: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKLOOP: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKLOOP: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
@ -1426,16 +1427,13 @@ func @conv3d_no_symbols(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %o
// CHECKPARALLEL: %[[dim3:.*]] = dim %[[arg2]], %[[c0]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[dim4:.*]] = dim %[[arg2]], %[[c1]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[dim5:.*]] = dim %[[arg2]], %[[c2]] : memref<?x?x?xf32>
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]]) = (%[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]]) step (%[[c1]], %[[c1]], %[[c1]]) {
// CHECKPARALLEL: scf.for %[[arg6:.*]] = %[[c0]] to %[[dim0]] step %[[c1]] {
// CHECKPARALLEL: scf.for %[[arg7:.*]] = %[[c0]] to %[[dim1]] step %[[c1]] {
// CHECKPARALLEL: scf.for %[[arg8:.*]] = %[[c0]] to %[[dim2]] step %[[c1]] {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
// CHECKPARALLEL: scf.parallel (%[[arg3:.*]], %[[arg4:.*]], %[[arg5:.*]], %[[arg6:.*]], %[[arg7:.*]], %[[arg8:.*]]) = (%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]) to (%[[dim3]], %[[dim4]], %[[dim5]], %[[dim0]], %[[dim1]], %[[dim2]]) step (%[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c1]]) {
// CHECKPARALLEL: %[[aff:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg3]], %[[arg6]])
// CHECKPARALLEL: %[[aff2:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg4]], %[[arg7]])
// CHECKPARALLEL: %[[aff3:.*]] = affine.apply #[[$stride1Dilation1]](%[[arg5]], %[[arg8]])
// CHECKPARALLEL: %[[vb:.*]] = load %[[arg0]][%[[aff]], %[[aff2]], %[[aff3]]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[va:.*]] = load %[[arg1]][%[[arg6]], %[[arg7]], %[[arg8]]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[vc:.*]] = load %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>
// CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32
// CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32
// CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref<?x?x?xf32>