[MLIR][Shape] Allow `get_extent` to operate on extent tensors and indices

Differential Revision: https://reviews.llvm.org/D84435
This commit is contained in:
Frederik Gossen 2020-07-24 11:12:39 +00:00
parent 7f600da828
commit 5984d74139
6 changed files with 114 additions and 29 deletions

View File

@ -235,9 +235,10 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
an error then it returns an error size.
}];
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
Shape_SizeType:$dim);
let results = (outs Shape_SizeType:$extent);
let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
Shape_SizeOrIndexType:$dim);
let results = (outs Shape_SizeOrIndexType:$extent);
let assemblyFormat = "$shape `,` $dim `:` type($shape) `,` type($dim) `->` "
"type($extent) attr-dict";
let builders = [
// Builder that allows passing a constant dimension as a simple integer.
@ -251,6 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
}];
let hasFolder = 1;
let verifier = [{ return ::verify(*this); }];
}
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {

View File

@ -535,10 +535,30 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
// GetExtentOp
//===----------------------------------------------------------------------===//
Optional<int64_t> GetExtentOp::getConstantDim() {
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
return constSizeOp.value().getLimitedValue();
static LogicalResult verify(GetExtentOp op) {
Type shapeTy = op.shape().getType();
Type dimTy = op.dim().getType();
Type extentTy = op.extent().getType();
bool errorPropagationPossible =
shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
if (errorPropagationPossible) {
if (!extentTy.isa<SizeType>())
op.emitError()
<< "if at least one of the operands can hold error values then the "
"result must be of type `size` to propagate them";
} else {
if (extentTy.isa<SizeType>())
op.emitError() << "if none of the operands can hold error values then "
"the result must be of type `index`";
}
return success();
}
Optional<int64_t> GetExtentOp::getConstantDim() {
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
return constSizeOp.value().getLimitedValue();
if (auto constantOp = dim().getDefiningOp<ConstantOp>())
return constantOp.value().cast<IntegerAttr>().getInt();
return llvm::None;
}
@ -558,8 +578,14 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
int64_t dim) {
auto loc = result.location;
auto dimAttr = builder.getIndexAttr(dim);
Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
build(builder, result, shape, dimValue);
if (shape.getType().isa<ShapeType>()) {
Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
build(builder, result, builder.getType<SizeType>(), shape, dim);
} else {
Value dim =
builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
build(builder, result, builder.getIndexType(), shape, dim);
}
}
//===----------------------------------------------------------------------===//

View File

@ -136,28 +136,25 @@ func @rank(%shape : tensor<?xindex>) -> index {
// `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
-> !shape.size {
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
// CHECK: return %[[RESULT]] : index
%shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
%result = shape.get_extent %shape, %idx : tensor<?xindex>
return %result : !shape.size
%result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
return %result : index
}
// -----
// Express `get_extent` as `std.extract_element` when it relies directly on the
// outcome of a `from_extent_tensor` operation.
// Express `get_extent` as `std.extract_element`.
// CHECK-LABEL: @get_extent_from_extent_tensor
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
%idx : !shape.size) -> !shape.size {
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
-> index {
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
// CHECK: return %[[RESULT]] : index
%shape = shape.from_extent_tensor %extents : tensor<?xindex>
%result = shape.get_extent %shape, %idx : !shape.shape
return %result : !shape.size
%result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
return %result : index
}
// -----

View File

@ -235,13 +235,49 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
// -----
// Basic folding.
// CHECK-LABEL: func @basic
func @basic() -> index {
// CHECK: constant 2 : index
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%c2 = constant 2 : index
%1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
return %1 : index
}
// -----
// Should not fold.
// CHECK-LABEL: func @out_of_bounds
func @out_of_bounds() -> index {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%c3 = constant 3 : index
%1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
return %1 : index
}
// -----
// Should not fold.
// CHECK-LABEL: func @not_const
func @not_const(%arg0: tensor<?xindex>) -> index {
// CHECK: shape.get_extent
%c3 = constant 3 : index
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>, index -> index
return %0 : index
}
// -----
// Basic folding.
// CHECK-LABEL: func @basic
func @basic() -> !shape.size {
// CHECK: shape.const_size 2
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%0 = shape.const_shape [0, 1, 2] : !shape.shape
%c2 = shape.const_size 2
%1 = shape.get_extent %0, %c2 : tensor<?xindex>
%1 = shape.get_extent %0, %c2 : !shape.shape, !shape.size -> !shape.size
return %1 : !shape.size
}
@ -252,9 +288,9 @@ func @basic() -> !shape.size {
func @out_of_bounds() -> !shape.size {
// CHECK: shape.const_shape
// CHECK: shape.get_extent
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
%0 = shape.const_shape [0, 1, 2] : !shape.shape
%c3 = shape.const_size 3
%1 = shape.get_extent %0, %c3 : tensor<?xindex>
%1 = shape.get_extent %0, %c3 : !shape.shape, !shape.size -> !shape.size
return %1 : !shape.size
}
@ -262,14 +298,13 @@ func @out_of_bounds() -> !shape.size {
// Should not fold.
// CHECK-LABEL: func @not_const
func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
func @not_const(%arg0 : !shape.shape) -> !shape.size {
// CHECK: shape.get_extent
%c3 = shape.const_size 3
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
%0 = shape.get_extent %arg0, %c3 : !shape.shape, !shape.size -> !shape.size
return %0 : !shape.size
}
// -----
// cstr_eq with non-constant but known equal shapes can be removed.
// CHECK-LABEL: func @f

View File

@ -102,3 +102,21 @@ func @rank(%arg : !shape.shape) {
%0 = shape.rank %arg : !shape.shape -> index
}
// -----
func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
%c0 = constant 0 : index
// expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
return %result : !shape.size
}
// -----
func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
%c0 = shape.const_size 0
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
return %result : index
}

View File

@ -163,13 +163,20 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
%c0 = shape.const_size 0
%result = shape.get_extent %arg, %c0 : !shape.shape
%result = shape.get_extent %arg, %c0 :
!shape.shape, !shape.size -> !shape.size
return %result : !shape.size
}
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
%c0 = constant 0 : index
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
return %result : index
}
func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
%c0 = shape.const_size 0
%result = shape.get_extent %arg, %c0 : tensor<?xindex>
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
return %result : !shape.size
}