forked from OSchip/llvm-project
[mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.
Also use `ArrayAttr` to pass iterator pass to the TiledLoopOp builder. Differential Revision: https://reviews.llvm.org/D98871
This commit is contained in:
parent
c539be1dcb
commit
283799157e
|
@ -15,6 +15,8 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
|
||||
|
||||
def Linalg_Dialect : Dialect {
|
||||
let name = "linalg";
|
||||
let description = [{
|
||||
|
|
|
@ -496,21 +496,25 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
|||
let summary = "Linalg tiled loop operation";
|
||||
let description = [{
|
||||
This is a loop-like operation with additional properties. The arguments
|
||||
also include the input and the output tensors and the attributes to specify
|
||||
the iterator types. The body region of the loop contains `subtensor`
|
||||
operations applied to every tensor argument of TiledLoopOp.
|
||||
also include the input and the output tensors or memrefs and the attributes
|
||||
to specify the iterator types.
|
||||
|
||||
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
|
||||
to "parallel" type, when it is absent from the custom format.
|
||||
|
||||
Tensor-based version:
|
||||
|
||||
The body region of the loop contains `subtensor` operations applied to
|
||||
every tensor argument of TiledLoopOp.
|
||||
|
||||
The body region must contain exactly one block that terminates with
|
||||
`linalg.yield` with the operands resulting from `subtensor_insert`
|
||||
operations.
|
||||
|
||||
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
|
||||
to "parallel" type, when it is absent from the custom format.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
|
||||
%0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
|
||||
ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
|
||||
outs(%out : tensor<24x64xi8>)
|
||||
iterators("parallel") {
|
||||
|
@ -528,13 +532,40 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
|||
linalg.yield %result : tensor<24x64xi8>
|
||||
}
|
||||
```
|
||||
|
||||
MemRef-based version:
|
||||
|
||||
The body region of the loop contains `subview` operations applied to
|
||||
every memref argument of TiledLoopOp.
|
||||
|
||||
The body region must contain exactly one block that terminates with
|
||||
`linalg.yield` with no operands.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
|
||||
ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
|
||||
outs(%out : memref<24x64xi8>)
|
||||
iterators("parallel") {
|
||||
%lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
|
||||
: memref<24x64xi8> to memref<?x?xi8>
|
||||
%rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
|
||||
: memref<24x64xi8> to memref<?x?xi8>
|
||||
%out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1]
|
||||
: memref<24x64xi8> to memref<?x?xi8>
|
||||
|
||||
%result_sub = linalg.generic ...
|
||||
linalg.yield
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<Index>:$lowerBound,
|
||||
Variadic<Index>:$upperBound,
|
||||
Variadic<Index>:$step,
|
||||
Variadic<AnyRankedTensor>:$inputs,
|
||||
Variadic<AnyRankedTensor>:$outputs,
|
||||
Variadic<LinalgOperand>:$inputs,
|
||||
Variadic<LinalgOperand>:$outputs,
|
||||
ArrayAttr:$iterator_types);
|
||||
let results = (outs Variadic<AnyRankedTensor>:$results);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
@ -542,7 +573,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
|||
let builders = [
|
||||
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
|
||||
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
|
||||
"ArrayRef<StringRef>":$iteratorTypes,
|
||||
"ArrayAttr":$iteratorTypes,
|
||||
CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
|
||||
"nullptr">:$bodyBuilderFn)>,
|
||||
];
|
||||
|
|
|
@ -496,8 +496,6 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Generic Linalg ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
|
||||
|
||||
class LinalgOperandOfRank<int rank>: Type<
|
||||
And<[
|
||||
LinalgOperand.predicate,
|
||||
|
|
|
@ -1744,7 +1744,7 @@ static LogicalResult verify(linalg::YieldOp op) {
|
|||
void TiledLoopOp::build(
|
||||
OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
|
||||
ValueRange upperBounds, ValueRange steps, ValueRange inputs,
|
||||
ValueRange outputs, ArrayRef<StringRef> iteratorTypes,
|
||||
ValueRange outputs, ArrayAttr iteratorTypes,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
|
||||
result.addOperands(lowerBounds);
|
||||
result.addOperands(upperBounds);
|
||||
|
@ -1758,9 +1758,14 @@ void TiledLoopOp::build(
|
|||
static_cast<int32_t>(steps.size()),
|
||||
static_cast<int32_t>(inputs.size()),
|
||||
static_cast<int32_t>(outputs.size())}));
|
||||
result.addAttribute(getIteratorTypesAttrName(),
|
||||
builder.getStrArrayAttr(iteratorTypes));
|
||||
result.addTypes(outputs.getTypes());
|
||||
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
|
||||
|
||||
// Add output types for `RankedTensorType` output arguments.
|
||||
for (Value output : outputs) {
|
||||
Type outputType = output.getType();
|
||||
if (outputType.isa<RankedTensorType>())
|
||||
result.addTypes(outputType);
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
unsigned numIVs = steps.size();
|
||||
|
@ -1771,8 +1776,8 @@ void TiledLoopOp::build(
|
|||
if (bodyBuilderFn) {
|
||||
builder.setInsertionPointToStart(bodyBlock);
|
||||
bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
|
||||
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
|
||||
}
|
||||
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, TiledLoopOp op) {
|
||||
|
|
Loading…
Reference in New Issue