[mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.

Also use `ArrayAttr` to pass iterator pass to the TiledLoopOp builder.

Differential Revision: https://reviews.llvm.org/D98871
This commit is contained in:
Alexander Belyaev 2021-03-18 16:04:02 +01:00
parent c539be1dcb
commit 283799157e
4 changed files with 53 additions and 17 deletions

View File

@ -15,6 +15,8 @@
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
def Linalg_Dialect : Dialect { def Linalg_Dialect : Dialect {
let name = "linalg"; let name = "linalg";
let description = [{ let description = [{

View File

@ -496,21 +496,25 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let summary = "Linalg tiled loop operation"; let summary = "Linalg tiled loop operation";
let description = [{ let description = [{
This is a loop-like operation with additional properties. The arguments This is a loop-like operation with additional properties. The arguments
also include the input and the output tensors and the attributes to specify also include the input and the output tensors or memrefs and the attributes
the iterator types. The body region of the loop contains `subtensor` to specify the iterator types.
operations applied to every tensor argument of TiledLoopOp.
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
to "parallel" type, when it is absent from the custom format.
Tensor-based version:
The body region of the loop contains `subtensor` operations applied to
every tensor argument of TiledLoopOp.
The body region must contain exactly one block that terminates with The body region must contain exactly one block that terminates with
`linalg.yield` with the operands resulting from `subtensor_insert` `linalg.yield` with the operands resulting from `subtensor_insert`
operations. operations.
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
to "parallel" type, when it is absent from the custom format.
Example: Example:
```mlir ```mlir
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>) ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
outs(%out : tensor<24x64xi8>) outs(%out : tensor<24x64xi8>)
iterators("parallel") { iterators("parallel") {
@ -528,13 +532,40 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
linalg.yield %result : tensor<24x64xi8> linalg.yield %result : tensor<24x64xi8>
} }
``` ```
MemRef-based version:
The body region of the loop contains `subview` operations applied to
every memref argument of TiledLoopOp.
The body region must contain exactly one block that terminates with
`linalg.yield` with no operands.
Example:
```mlir
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
outs(%out : memref<24x64xi8>)
iterators("parallel") {
%lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
: memref<24x64xi8> to memref<?x?xi8>
%rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
: memref<24x64xi8> to memref<?x?xi8>
%out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1]
: memref<24x64xi8> to memref<?x?xi8>
%result_sub = linalg.generic ...
linalg.yield
}
```
}]; }];
let arguments = (ins Variadic<Index>:$lowerBound, let arguments = (ins Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound, Variadic<Index>:$upperBound,
Variadic<Index>:$step, Variadic<Index>:$step,
Variadic<AnyRankedTensor>:$inputs, Variadic<LinalgOperand>:$inputs,
Variadic<AnyRankedTensor>:$outputs, Variadic<LinalgOperand>:$outputs,
ArrayAttr:$iterator_types); ArrayAttr:$iterator_types);
let results = (outs Variadic<AnyRankedTensor>:$results); let results = (outs Variadic<AnyRankedTensor>:$results);
let regions = (region SizedRegion<1>:$region); let regions = (region SizedRegion<1>:$region);
@ -542,7 +573,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
let builders = [ let builders = [
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
"ArrayRef<StringRef>":$iteratorTypes, "ArrayAttr":$iteratorTypes,
CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>", CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
"nullptr">:$bodyBuilderFn)>, "nullptr">:$bodyBuilderFn)>,
]; ];

View File

@ -496,8 +496,6 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Generic Linalg ops. // Generic Linalg ops.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
class LinalgOperandOfRank<int rank>: Type< class LinalgOperandOfRank<int rank>: Type<
And<[ And<[
LinalgOperand.predicate, LinalgOperand.predicate,

View File

@ -1744,7 +1744,7 @@ static LogicalResult verify(linalg::YieldOp op) {
void TiledLoopOp::build( void TiledLoopOp::build(
OpBuilder &builder, OperationState &result, ValueRange lowerBounds, OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
ValueRange upperBounds, ValueRange steps, ValueRange inputs, ValueRange upperBounds, ValueRange steps, ValueRange inputs,
ValueRange outputs, ArrayRef<StringRef> iteratorTypes, ValueRange outputs, ArrayAttr iteratorTypes,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) { function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
result.addOperands(lowerBounds); result.addOperands(lowerBounds);
result.addOperands(upperBounds); result.addOperands(upperBounds);
@ -1758,9 +1758,14 @@ void TiledLoopOp::build(
static_cast<int32_t>(steps.size()), static_cast<int32_t>(steps.size()),
static_cast<int32_t>(inputs.size()), static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outputs.size())})); static_cast<int32_t>(outputs.size())}));
result.addAttribute(getIteratorTypesAttrName(), result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
builder.getStrArrayAttr(iteratorTypes));
result.addTypes(outputs.getTypes()); // Add output types for `RankedTensorType` output arguments.
for (Value output : outputs) {
Type outputType = output.getType();
if (outputType.isa<RankedTensorType>())
result.addTypes(outputType);
}
OpBuilder::InsertionGuard guard(builder); OpBuilder::InsertionGuard guard(builder);
unsigned numIVs = steps.size(); unsigned numIVs = steps.size();
@ -1771,8 +1776,8 @@ void TiledLoopOp::build(
if (bodyBuilderFn) { if (bodyBuilderFn) {
builder.setInsertionPointToStart(bodyBlock); builder.setInsertionPointToStart(bodyBlock);
bodyBuilderFn(builder, result.location, bodyBlock->getArguments()); bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
} }
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
} }
static void print(OpAsmPrinter &p, TiledLoopOp op) { static void print(OpAsmPrinter &p, TiledLoopOp op) {