[mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.

Also use `ArrayAttr` to pass iterator types to the TiledLoopOp builder.

Differential Revision: https://reviews.llvm.org/D98871
This commit is contained in:
Alexander Belyaev 2021-03-18 16:04:02 +01:00
parent c539be1dcb
commit 283799157e
4 changed files with 53 additions and 17 deletions

View File

@ -15,6 +15,8 @@
include "mlir/IR/OpBase.td"
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
def Linalg_Dialect : Dialect {
let name = "linalg";
let description = [{

View File

@ -496,21 +496,25 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let summary = "Linalg tiled loop operation";
let description = [{
This is a loop-like operation with additional properties. The arguments
also include the input and the output tensors and the attributes to specify
the iterator types. The body region of the loop contains `subtensor`
operations applied to every tensor argument of TiledLoopOp.
also include the input and the output tensors or memrefs and the attributes
to specify the iterator types.
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
to "parallel" type, when it is absent from the custom format.
Tensor-based version:
The body region of the loop contains `subtensor` operations applied to
every tensor argument of TiledLoopOp.
The body region must contain exactly one block that terminates with
`linalg.yield` with the operands resulting from `subtensor_insert`
operations.
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
to "parallel" type, when it is absent from the custom format.
Example:
```mlir
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
%0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
outs(%out : tensor<24x64xi8>)
iterators("parallel") {
@ -528,13 +532,40 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
linalg.yield %result : tensor<24x64xi8>
}
```
MemRef-based version:
The body region of the loop contains `subview` operations applied to
every memref argument of TiledLoopOp.
The body region must contain exactly one block that terminates with
`linalg.yield` with no operands.
Example:
```mlir
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
outs(%out : memref<24x64xi8>)
iterators("parallel") {
%lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
: memref<24x64xi8> to memref<?x?xi8>
%rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
: memref<24x64xi8> to memref<?x?xi8>
%out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1]
: memref<24x64xi8> to memref<?x?xi8>
%result_sub = linalg.generic ...
linalg.yield
}
```
}];
let arguments = (ins Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound,
Variadic<Index>:$step,
Variadic<AnyRankedTensor>:$inputs,
Variadic<AnyRankedTensor>:$outputs,
Variadic<LinalgOperand>:$inputs,
Variadic<LinalgOperand>:$outputs,
ArrayAttr:$iterator_types);
let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$region);
@ -542,7 +573,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let builders = [
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
"ArrayRef<StringRef>":$iteratorTypes,
"ArrayAttr":$iteratorTypes,
CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
"nullptr">:$bodyBuilderFn)>,
];

View File

@ -496,8 +496,6 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
//===----------------------------------------------------------------------===//
// Generic Linalg ops.
//===----------------------------------------------------------------------===//
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
class LinalgOperandOfRank<int rank>: Type<
And<[
LinalgOperand.predicate,

View File

@ -1744,7 +1744,7 @@ static LogicalResult verify(linalg::YieldOp op) {
void TiledLoopOp::build(
OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
ValueRange upperBounds, ValueRange steps, ValueRange inputs,
ValueRange outputs, ArrayRef<StringRef> iteratorTypes,
ValueRange outputs, ArrayAttr iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
result.addOperands(lowerBounds);
result.addOperands(upperBounds);
@ -1758,9 +1758,14 @@ void TiledLoopOp::build(
static_cast<int32_t>(steps.size()),
static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())}));
result.addAttribute(getIteratorTypesAttrName(),
builder.getStrArrayAttr(iteratorTypes));
result.addTypes(outputs.getTypes());
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
// Add output types for `RankedTensorType` output arguments.
for (Value output : outputs) {
Type outputType = output.getType();
if (outputType.isa<RankedTensorType>())
result.addTypes(outputType);
}
OpBuilder::InsertionGuard guard(builder);
unsigned numIVs = steps.size();
@ -1771,8 +1776,8 @@ void TiledLoopOp::build(
if (bodyBuilderFn) {
builder.setInsertionPointToStart(bodyBlock);
bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
}
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
}
static void print(OpAsmPrinter &p, TiledLoopOp op) {