forked from OSchip/llvm-project
[MLIR][Shape] Allow `shape.reduce` to operate on extent tensors
Allow `shape.reduce` to take both `shape.shape` and `tensor<?xindex>` as an argument. Differential Revision: https://reviews.llvm.org/D83943
This commit is contained in:
parent
920e127e02
commit
0eb50e614c
|
@ -338,23 +338,26 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
|
|||
|
||||
def Shape_ReduceOp : Shape_Op<"reduce",
|
||||
[SingleBlockImplicitTerminator<"YieldOp">]> {
|
||||
let summary = "Returns an expression reduced over a shape";
|
||||
let summary = "Returns an expression reduced over a shape or extent tensor";
|
||||
let description = [{
|
||||
An operation that takes as input a shape, number of initial values and has a
|
||||
region/function that is applied repeatedly for every dimension of the shape.
|
||||
An operation that takes as input a shape or extent tensor, and a number of
|
||||
initial values. This operation has a region/function that is applied
|
||||
repeatedly for every extent of the input. Starting with the initial values,
|
||||
the individual extents are then aggregated as defined by the associated
|
||||
region.
|
||||
|
||||
Conceptually this op performs the following reduction:
|
||||
|
||||
```
|
||||
res[] = init;
|
||||
for (int i = 0, e = shape.rank(); i != e; ++i) {
|
||||
for (int i = 0, i < shape.rank(); i++) {
|
||||
res = fn(i, shape[i], res[0], ..., res[n]);
|
||||
}
|
||||
```
|
||||
|
||||
Where fn is provided by the user and the result of the reduce op is the
|
||||
Where `fn` is provided by the user and the result of the reduce op is the
|
||||
last computed output of the reduce function. As an example, computing the
|
||||
number of elements
|
||||
number of elements can be defined as follows:
|
||||
|
||||
```mlir
|
||||
func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
|
||||
|
@ -367,11 +370,10 @@ def Shape_ReduceOp : Shape_Op<"reduce",
|
|||
return %num_elements : !shape.size
|
||||
}
|
||||
```
|
||||
|
||||
If the shape is unranked, then the results of the op is also unranked.
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
|
||||
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
|
||||
Variadic<AnyType>:$initVals);
|
||||
let results = (outs Variadic<AnyType>:$result);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
|
|
|
@ -721,18 +721,31 @@ static LogicalResult verify(ReduceOp op) {
|
|||
// Verify block arg types.
|
||||
Block &block = op.region().front();
|
||||
|
||||
// The block takes index, extent, and aggregated values as arguments.
|
||||
auto blockArgsCount = op.initVals().size() + 2;
|
||||
if (block.getNumArguments() != blockArgsCount)
|
||||
return op.emitOpError() << "ReduceOp body is expected to have "
|
||||
<< blockArgsCount << " arguments";
|
||||
|
||||
if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
|
||||
// The first block argument is the index and must always be of type `index`.
|
||||
if (!block.getArgument(0).getType().isa<IndexType>())
|
||||
return op.emitOpError(
|
||||
"argument 0 of ReduceOp body is expected to be of IndexType");
|
||||
|
||||
if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
|
||||
return op.emitOpError(
|
||||
"argument 1 of ReduceOp body is expected to be of SizeType");
|
||||
// The second block argument is the extent and must be of type `size` or
|
||||
// `index`, depending on whether the reduce operation is applied to a shape or
|
||||
// to an extent tensor.
|
||||
Type extentTy = block.getArgument(1).getType();
|
||||
if (op.shape().getType().isa<ShapeType>()) {
|
||||
if (!extentTy.isa<SizeType>())
|
||||
return op.emitOpError("argument 1 of ReduceOp body is expected to be of "
|
||||
"SizeType if the ReduceOp operates on a ShapeType");
|
||||
} else {
|
||||
if (!extentTy.isa<IndexType>())
|
||||
return op.emitOpError(
|
||||
"argument 1 of ReduceOp body is expected to be of IndexType if the "
|
||||
"ReduceOp operates on an extent tensor");
|
||||
}
|
||||
|
||||
for (auto type : llvm::enumerate(op.initVals()))
|
||||
if (block.getArgument(type.index() + 2).getType() != type.value().getType())
|
||||
|
@ -743,17 +756,18 @@ static LogicalResult verify(ReduceOp op) {
|
|||
}
|
||||
|
||||
static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
|
||||
auto *ctx = parser.getBuilder().getContext();
|
||||
// Parse operands.
|
||||
SmallVector<OpAsmParser::OperandType, 3> operands;
|
||||
Type shapeOrExtentTensorType;
|
||||
if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser.parseColonType(shapeOrExtentTensorType) ||
|
||||
parser.parseOptionalArrowTypeList(result.types))
|
||||
return failure();
|
||||
|
||||
// Resolve operands.
|
||||
auto initVals = llvm::makeArrayRef(operands).drop_front();
|
||||
if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
|
||||
if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
|
||||
result.operands) ||
|
||||
parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
|
||||
result.operands))
|
||||
|
@ -773,7 +787,7 @@ static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
static void print(OpAsmPrinter &p, ReduceOp op) {
|
||||
p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
|
||||
<< ") ";
|
||||
<< ") : " << op.shape().getType();
|
||||
p.printOptionalArrowTypeList(op.getResultTypes());
|
||||
p.printRegion(op.region());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: shape_reduce
|
||||
// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
|
||||
// CHECK-LABEL: @shape_reduce
|
||||
// CHECK-SAME: ([[SHAPE:%.*]]: !shape.shape) -> !shape.size
|
||||
func @shape_reduce(%shape : !shape.shape) -> !shape.size {
|
||||
%init = shape.const_size 1
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
|
||||
%new_acc = shape.mul %acc, %dim
|
||||
shape.yield %new_acc : !shape.size
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
||||
// expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index: index, %dim: !shape.size):
|
||||
shape.yield %dim : !shape.size
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
|||
|
||||
func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
|
||||
// expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size):
|
||||
%new_acc = "shape.add"(%acc, %dim)
|
||||
: (!shape.size, !shape.size) -> !shape.size
|
||||
|
@ -23,8 +23,8 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
|
|||
// -----
|
||||
|
||||
func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
|
||||
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType if the ReduceOp operates on a ShapeType}}
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index: index, %dim: f32, %lci: !shape.size):
|
||||
shape.yield
|
||||
}
|
||||
|
@ -32,9 +32,19 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
|
|||
|
||||
// -----
|
||||
|
||||
func @reduce_op_arg1_wrong_type(%shape : tensor<?xindex>, %init : index) {
|
||||
// expected-error@+1 {{argument 1 of ReduceOp body is expected to be of IndexType if the ReduceOp operates on an extent tensor}}
|
||||
%num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
|
||||
^bb0(%index: index, %dim: f32, %lci: index):
|
||||
shape.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
|
||||
// expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
|
||||
%num_elements = shape.reduce(%shape, %init) -> f32 {
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> f32 {
|
||||
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
|
||||
shape.yield
|
||||
}
|
||||
|
@ -44,7 +54,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
|
|||
|
||||
func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
||||
// expected-error@+3 {{number of operands does not match number of results of its parent}}
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
|
||||
shape.yield %dim, %dim : !shape.size, !shape.size
|
||||
}
|
||||
|
@ -54,7 +64,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
|||
|
||||
func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
|
||||
// expected-error@+4 {{types mismatch between yield op and its parent}}
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
|
||||
%c0 = constant 1 : index
|
||||
shape.yield %c0 : index
|
||||
|
|
|
@ -6,15 +6,26 @@
|
|||
|
||||
// CHECK-LABEL: shape_num_elements
|
||||
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
|
||||
%init = shape.const_size 0
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
|
||||
%acc = shape.add %lci, %dim
|
||||
shape.yield %acc : !shape.size
|
||||
%init = shape.const_size 1
|
||||
%num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size {
|
||||
^bb0(%index : index, %extent : !shape.size, %acc : !shape.size):
|
||||
%acc_next = shape.mul %acc, %extent
|
||||
shape.yield %acc_next : !shape.size
|
||||
}
|
||||
return %num_elements : !shape.size
|
||||
}
|
||||
|
||||
// CHECK-LABEL: extent_tensor_num_elements
|
||||
func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index {
|
||||
%init = constant 1 : index
|
||||
%num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
|
||||
^bb0(%index : index, %extent : index, %acc : index):
|
||||
%acc_next = muli %acc, %extent : index
|
||||
shape.yield %acc_next : index
|
||||
}
|
||||
return %num_elements : index
|
||||
}
|
||||
|
||||
func @test_shape_num_elements_unknown() {
|
||||
%0 = "shape.unknown_shape"() : () -> !shape.shape
|
||||
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
|
||||
|
|
|
@ -1,16 +1,16 @@
|
|||
// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @num_elements_to_reduce(
|
||||
// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {
|
||||
// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> !shape.size {
|
||||
func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
|
||||
%num_elements = shape.num_elements %shape
|
||||
return %num_elements : !shape.size
|
||||
}
|
||||
// CHECK: [[C1:%.*]] = shape.const_size 1
|
||||
// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) -> [[SIZE_TY]]
|
||||
// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]]
|
||||
// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : !shape.shape -> !shape.size
|
||||
// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: !shape.size, [[ACC:%.*]]: !shape.size
|
||||
// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
|
||||
// CHECK: shape.yield [[NEW_ACC]] : [[SIZE_TY]]
|
||||
// CHECK: shape.yield [[NEW_ACC]] : !shape.size
|
||||
// CHECK: }
|
||||
// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]]
|
||||
// CHECK: return [[NUM_ELEMENTS]] : !shape.size
|
||||
|
||||
|
|
Loading…
Reference in New Issue