[MLIR][Shape] Allow `shape.reduce` to operate on extent tensors

Allow `shape.reduce` to take both `shape.shape` and `tensor<?xindex>` as an
argument.

Differential Revision: https://reviews.llvm.org/D83943
This commit is contained in:
Frederik Gossen 2020-07-16 13:52:42 +00:00
parent 920e127e02
commit 0eb50e614c
6 changed files with 73 additions and 36 deletions

View File

@ -338,23 +338,26 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
def Shape_ReduceOp : Shape_Op<"reduce",
[SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "Returns an expression reduced over a shape";
let summary = "Returns an expression reduced over a shape or extent tensor";
let description = [{
An operation that takes as input a shape, number of initial values and has a
region/function that is applied repeatedly for every dimension of the shape.
An operation that takes as input a shape or extent tensor, and a number of
initial values. This operation has a region/function that is applied
repeatedly for every extent of the input. Starting with the initial values,
the individual extents are then aggregated as defined by the associated
region.
Conceptually this op performs the following reduction:
```
res[] = init;
for (int i = 0, e = shape.rank(); i != e; ++i) {
for (int i = 0, i < shape.rank(); i++) {
res = fn(i, shape[i], res[0], ..., res[n]);
}
```
Where fn is provided by the user and the result of the reduce op is the
Where `fn` is provided by the user and the result of the reduce op is the
last computed output of the reduce function. As an example, computing the
number of elements
number of elements can be defined as follows:
```mlir
func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
@ -367,11 +370,10 @@ def Shape_ReduceOp : Shape_Op<"reduce",
return %num_elements : !shape.size
}
```
If the shape is unranked, then the results of the op is also unranked.
}];
let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
Variadic<AnyType>:$initVals);
let results = (outs Variadic<AnyType>:$result);
let regions = (region SizedRegion<1>:$region);

View File

@ -721,18 +721,31 @@ static LogicalResult verify(ReduceOp op) {
// Verify block arg types.
Block &block = op.region().front();
// The block takes index, extent, and aggregated values as arguments.
auto blockArgsCount = op.initVals().size() + 2;
if (block.getNumArguments() != blockArgsCount)
return op.emitOpError() << "ReduceOp body is expected to have "
<< blockArgsCount << " arguments";
if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
// The first block argument is the index and must always be of type `index`.
if (!block.getArgument(0).getType().isa<IndexType>())
return op.emitOpError(
"argument 0 of ReduceOp body is expected to be of IndexType");
if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
return op.emitOpError(
"argument 1 of ReduceOp body is expected to be of SizeType");
// The second block argument is the extent and must be of type `size` or
// `index`, depending on whether the reduce operation is applied to a shape or
// to an extent tensor.
Type extentTy = block.getArgument(1).getType();
if (op.shape().getType().isa<ShapeType>()) {
if (!extentTy.isa<SizeType>())
return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
"SizeType if the ReduceOp operates on a ShapeType");
} else {
if (!extentTy.isa<IndexType>())
return op.emitOpError(
"argument 1 of ReduceOp body is expected to be of IndexType if the "
"ReduceOp operates on an extent tensor");
}
for (auto type : llvm::enumerate(op.initVals()))
if (block.getArgument(type.index() + 2).getType() != type.value().getType())
@ -743,17 +756,18 @@ static LogicalResult verify(ReduceOp op) {
}
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
auto *ctx = parser.getBuilder().getContext();
// Parse operands.
SmallVector<OpAsmParser::OperandType, 3> operands;
Type shapeOrExtentTensorType;
if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser.parseColonType(shapeOrExtentTensorType) ||
parser.parseOptionalArrowTypeList(result.types))
return failure();
// Resolve operands.
auto initVals = llvm::makeArrayRef(operands).drop_front();
if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
result.operands) ||
parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
result.operands))
@ -773,7 +787,7 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
static void print(OpAsmPrinter &p, ReduceOp op) {
p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
<< ") ";
<< ") : " << op.shape().getType();
p.printOptionalArrowTypeList(op.getResultTypes());
p.printRegion(op.region());
p.printOptionalAttrDict(op.getAttrs());

View File

@ -1,10 +1,10 @@
// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
// CHECK-LABEL: shape_reduce
// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
// CHECK-LABEL: @shape_reduce
// CHECK-SAME: ([[SHAPE:%.*]]: !shape.shape) -> !shape.size
func @shape_reduce(%shape : !shape.shape) -> !shape.size {
%init = shape.const_size 1
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
%new_acc = shape.mul %acc, %dim
shape.yield %new_acc : !shape.size

View File

@ -2,7 +2,7 @@
func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size):
shape.yield %dim : !shape.size
}
@ -12,7 +12,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size):
%new_acc = "shape.add"(%acc, %dim)
: (!shape.size, !shape.size) -> !shape.size
@ -23,8 +23,8 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// -----
func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType if the ReduceOp operates on a ShapeType}}
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: f32, %lci: !shape.size):
shape.yield
}
@ -32,9 +32,19 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
// -----
func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of IndexType if the ReduceOp operates on an extent tensor}}
%num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
^bb0(%index: index, %dim: f32, %lci: index):
shape.yield
}
}
// -----
func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
// expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
%num_elements = shape.reduce(%shape, %init) -> f32 {
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> f32 {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
shape.yield
}
@ -44,7 +54,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+3 {{number of operands does not match number of results of its parent}}
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
shape.yield %dim, %dim : !shape.size, !shape.size
}
@ -54,7 +64,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
// expected-error@+4 {{types mismatch between yield op and its parent}}
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
%c0 = constant 1 : index
shape.yield %c0 : index

View File

@ -6,15 +6,26 @@
// CHECK-LABEL: shape_num_elements
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
%init = shape.const_size 0
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
%acc = shape.add %lci, %dim
shape.yield %acc : !shape.size
%init = shape.const_size 1
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
%acc_next = shape.mul %acc, %extent
shape.yield %acc_next : !shape.size
}
return %num_elements : !shape.size
}
// CHECK-LABEL: extent_tensor_num_elements
func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
%init = constant 1 : index
%num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
^bb0(%index : index, %extent : index, %acc : index):
%acc_next = muli %acc, %extent : index
shape.yield %acc_next : index
}
return %num_elements : index
}
func @test_shape_num_elements_unknown() {
%0 = "shape.unknown_shape"() : () -> !shape.shape
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)

View File

@ -1,16 +1,16 @@
// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s
// CHECK-LABEL: func @num_elements_to_reduce(
// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {
// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> !shape.size {
func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
%num_elements = shape.num_elements %shape
return %num_elements : !shape.size
}
// CHECK: [[C1:%.*]] = shape.const_size 1
// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) -> [[SIZE_TY]]
// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]]
// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : !shape.shape -> !shape.size
// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: !shape.size, [[ACC:%.*]]: !shape.size
// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
// CHECK: shape.yield [[NEW_ACC]] : [[SIZE_TY]]
// CHECK: shape.yield [[NEW_ACC]] : !shape.size
// CHECK: }
// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]]
// CHECK: return [[NUM_ELEMENTS]] : !shape.size