forked from OSchip/llvm-project
[mlir][linalg] Add support for memref inputs/outputs for `linalg.tiled_loop`.
Also use `ArrayAttr` to pass iterator pass to the TiledLoopOp builder. Differential Revision: https://reviews.llvm.org/D98871
This commit is contained in:
parent
c539be1dcb
commit
283799157e
|
@ -15,6 +15,8 @@
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
|
||||||
|
|
||||||
def Linalg_Dialect : Dialect {
|
def Linalg_Dialect : Dialect {
|
||||||
let name = "linalg";
|
let name = "linalg";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -496,21 +496,25 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
||||||
let summary = "Linalg tiled loop operation";
|
let summary = "Linalg tiled loop operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
This is a loop-like operation with additional properties. The arguments
|
This is a loop-like operation with additional properties. The arguments
|
||||||
also include the input and the output tensors and the attributes to specify
|
also include the input and the output tensors or memrefs and the attributes
|
||||||
the iterator types. The body region of the loop contains `subtensor`
|
to specify the iterator types.
|
||||||
operations applied to every tensor argument of TiledLoopOp.
|
|
||||||
|
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
|
||||||
|
to "parallel" type, when it is absent from the custom format.
|
||||||
|
|
||||||
|
Tensor-based version:
|
||||||
|
|
||||||
|
The body region of the loop contains `subtensor` operations applied to
|
||||||
|
every tensor argument of TiledLoopOp.
|
||||||
|
|
||||||
The body region must contain exactly one block that terminates with
|
The body region must contain exactly one block that terminates with
|
||||||
`linalg.yield` with the operands resulting from `subtensor_insert`
|
`linalg.yield` with the operands resulting from `subtensor_insert`
|
||||||
operations.
|
operations.
|
||||||
|
|
||||||
Parsing TiledLoopOp will set all elements of the `iterator_types` attribute
|
|
||||||
to "parallel" type, when it is absent from the custom format.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
|
%0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
|
||||||
ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
|
ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>)
|
||||||
outs(%out : tensor<24x64xi8>)
|
outs(%out : tensor<24x64xi8>)
|
||||||
iterators("parallel") {
|
iterators("parallel") {
|
||||||
|
@ -528,13 +532,40 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
||||||
linalg.yield %result : tensor<24x64xi8>
|
linalg.yield %result : tensor<24x64xi8>
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
MemRef-based version:
|
||||||
|
|
||||||
|
The body region of the loop contains `subview` operations applied to
|
||||||
|
every memref argument of TiledLoopOp.
|
||||||
|
|
||||||
|
The body region must contain exactly one block that terminates with
|
||||||
|
`linalg.yield` with no operands.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir
|
||||||
|
linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4)
|
||||||
|
ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>)
|
||||||
|
outs(%out : memref<24x64xi8>)
|
||||||
|
iterators("parallel") {
|
||||||
|
%lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1]
|
||||||
|
: memref<24x64xi8> to memref<?x?xi8>
|
||||||
|
%rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1]
|
||||||
|
: memref<24x64xi8> to memref<?x?xi8>
|
||||||
|
%out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1]
|
||||||
|
: memref<24x64xi8> to memref<?x?xi8>
|
||||||
|
|
||||||
|
%result_sub = linalg.generic ...
|
||||||
|
linalg.yield
|
||||||
|
}
|
||||||
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins Variadic<Index>:$lowerBound,
|
let arguments = (ins Variadic<Index>:$lowerBound,
|
||||||
Variadic<Index>:$upperBound,
|
Variadic<Index>:$upperBound,
|
||||||
Variadic<Index>:$step,
|
Variadic<Index>:$step,
|
||||||
Variadic<AnyRankedTensor>:$inputs,
|
Variadic<LinalgOperand>:$inputs,
|
||||||
Variadic<AnyRankedTensor>:$outputs,
|
Variadic<LinalgOperand>:$outputs,
|
||||||
ArrayAttr:$iterator_types);
|
ArrayAttr:$iterator_types);
|
||||||
let results = (outs Variadic<AnyRankedTensor>:$results);
|
let results = (outs Variadic<AnyRankedTensor>:$results);
|
||||||
let regions = (region SizedRegion<1>:$region);
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
@ -542,7 +573,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
|
OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds,
|
||||||
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
|
"ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs,
|
||||||
"ArrayRef<StringRef>":$iteratorTypes,
|
"ArrayAttr":$iteratorTypes,
|
||||||
CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
|
CArg<"function_ref<void (OpBuilder &, Location, ValueRange)>",
|
||||||
"nullptr">:$bodyBuilderFn)>,
|
"nullptr">:$bodyBuilderFn)>,
|
||||||
];
|
];
|
||||||
|
|
|
@ -496,8 +496,6 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Generic Linalg ops.
|
// Generic Linalg ops.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>;
|
|
||||||
|
|
||||||
class LinalgOperandOfRank<int rank>: Type<
|
class LinalgOperandOfRank<int rank>: Type<
|
||||||
And<[
|
And<[
|
||||||
LinalgOperand.predicate,
|
LinalgOperand.predicate,
|
||||||
|
|
|
@ -1744,7 +1744,7 @@ static LogicalResult verify(linalg::YieldOp op) {
|
||||||
void TiledLoopOp::build(
|
void TiledLoopOp::build(
|
||||||
OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
|
OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
|
||||||
ValueRange upperBounds, ValueRange steps, ValueRange inputs,
|
ValueRange upperBounds, ValueRange steps, ValueRange inputs,
|
||||||
ValueRange outputs, ArrayRef<StringRef> iteratorTypes,
|
ValueRange outputs, ArrayAttr iteratorTypes,
|
||||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
|
||||||
result.addOperands(lowerBounds);
|
result.addOperands(lowerBounds);
|
||||||
result.addOperands(upperBounds);
|
result.addOperands(upperBounds);
|
||||||
|
@ -1758,9 +1758,14 @@ void TiledLoopOp::build(
|
||||||
static_cast<int32_t>(steps.size()),
|
static_cast<int32_t>(steps.size()),
|
||||||
static_cast<int32_t>(inputs.size()),
|
static_cast<int32_t>(inputs.size()),
|
||||||
static_cast<int32_t>(outputs.size())}));
|
static_cast<int32_t>(outputs.size())}));
|
||||||
result.addAttribute(getIteratorTypesAttrName(),
|
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
|
||||||
builder.getStrArrayAttr(iteratorTypes));
|
|
||||||
result.addTypes(outputs.getTypes());
|
// Add output types for `RankedTensorType` output arguments.
|
||||||
|
for (Value output : outputs) {
|
||||||
|
Type outputType = output.getType();
|
||||||
|
if (outputType.isa<RankedTensorType>())
|
||||||
|
result.addTypes(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
OpBuilder::InsertionGuard guard(builder);
|
OpBuilder::InsertionGuard guard(builder);
|
||||||
unsigned numIVs = steps.size();
|
unsigned numIVs = steps.size();
|
||||||
|
@ -1771,8 +1776,8 @@ void TiledLoopOp::build(
|
||||||
if (bodyBuilderFn) {
|
if (bodyBuilderFn) {
|
||||||
builder.setInsertionPointToStart(bodyBlock);
|
builder.setInsertionPointToStart(bodyBlock);
|
||||||
bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
|
bodyBuilderFn(builder, result.location, bodyBlock->getArguments());
|
||||||
|
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
|
||||||
}
|
}
|
||||||
TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, TiledLoopOp op) {
|
static void print(OpAsmPrinter &p, TiledLoopOp op) {
|
||||||
|
|
Loading…
Reference in New Issue