forked from OSchip/llvm-project
[MLIR][Shape] Allow `get_extent` to operate on extent tensors and indices
Differential Revision: https://reviews.llvm.org/D84435
This commit is contained in:
parent
7f600da828
commit
5984d74139
|
@ -235,9 +235,10 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
|
|||
an error then it returns an error size.
|
||||
}];
|
||||
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
|
||||
Shape_SizeType:$dim);
|
||||
let results = (outs Shape_SizeType:$extent);
|
||||
let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
|
||||
Shape_SizeOrIndexType:$dim);
|
||||
let results = (outs Shape_SizeOrIndexType:$extent);
|
||||
let assemblyFormat = "$shape `,` $dim `:` type($shape) `,` type($dim) `->` "
|
||||
"type($extent) attr-dict";
|
||||
|
||||
let builders = [
|
||||
// Builder that allows passing a constant dimension as a simple integer.
|
||||
|
@ -251,6 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
|
||||
|
|
|
@ -535,10 +535,30 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
|
|||
// GetExtentOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Optional<int64_t> GetExtentOp::getConstantDim() {
|
||||
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
|
||||
return constSizeOp.value().getLimitedValue();
|
||||
static LogicalResult verify(GetExtentOp op) {
|
||||
Type shapeTy = op.shape().getType();
|
||||
Type dimTy = op.dim().getType();
|
||||
Type extentTy = op.extent().getType();
|
||||
bool errorPropagationPossible =
|
||||
shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
|
||||
if (errorPropagationPossible) {
|
||||
if (!extentTy.isa<SizeType>())
|
||||
op.emitError()
|
||||
<< "if at least one of the operands can hold error values then the "
|
||||
"result must be of type `size` to propagate them";
|
||||
} else {
|
||||
if (extentTy.isa<SizeType>())
|
||||
op.emitError() << "if none of the operands can hold error values then "
|
||||
"the result must be of type `index`";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Optional<int64_t> GetExtentOp::getConstantDim() {
|
||||
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
|
||||
return constSizeOp.value().getLimitedValue();
|
||||
if (auto constantOp = dim().getDefiningOp<ConstantOp>())
|
||||
return constantOp.value().cast<IntegerAttr>().getInt();
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
|
@ -558,8 +578,14 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
|
|||
int64_t dim) {
|
||||
auto loc = result.location;
|
||||
auto dimAttr = builder.getIndexAttr(dim);
|
||||
Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
|
||||
build(builder, result, shape, dimValue);
|
||||
if (shape.getType().isa<ShapeType>()) {
|
||||
Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
|
||||
build(builder, result, builder.getType<SizeType>(), shape, dim);
|
||||
} else {
|
||||
Value dim =
|
||||
builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
|
||||
build(builder, result, builder.getIndexType(), shape, dim);
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -136,28 +136,25 @@ func @rank(%shape : tensor<?xindex>) -> index {
|
|||
// `shape_of` operation.
|
||||
// CHECK-LABEL: @get_extent_shape_of
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
|
||||
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
|
||||
-> !shape.size {
|
||||
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
|
||||
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
|
||||
// CHECK: return %[[RESULT]] : index
|
||||
%shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
|
||||
%result = shape.get_extent %shape, %idx : tensor<?xindex>
|
||||
return %result : !shape.size
|
||||
%result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
|
||||
return %result : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Express `get_extent` as `std.extract_element` when it relies directly on the
|
||||
// outcome of a `from_extent_tensor` operation.
|
||||
// Express `get_extent` as `std.extract_element`.
|
||||
// CHECK-LABEL: @get_extent_from_extent_tensor
|
||||
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
|
||||
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
|
||||
%idx : !shape.size) -> !shape.size {
|
||||
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
|
||||
-> index {
|
||||
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
|
||||
// CHECK: return %[[RESULT]] : index
|
||||
%shape = shape.from_extent_tensor %extents : tensor<?xindex>
|
||||
%result = shape.get_extent %shape, %idx : !shape.shape
|
||||
return %result : !shape.size
|
||||
%result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
|
||||
return %result : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -235,13 +235,49 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
|
|||
|
||||
// -----
|
||||
|
||||
// Basic folding.
|
||||
// CHECK-LABEL: func @basic
|
||||
func @basic() -> index {
|
||||
// CHECK: constant 2 : index
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||
%c2 = constant 2 : index
|
||||
%1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Should not fold.
|
||||
// CHECK-LABEL: func @out_of_bounds
|
||||
func @out_of_bounds() -> index {
|
||||
// CHECK: shape.const_shape
|
||||
// CHECK: shape.get_extent
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||
%c3 = constant 3 : index
|
||||
%1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
|
||||
return %1 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Should not fold.
|
||||
// CHECK-LABEL: func @not_const
|
||||
func @not_const(%arg0: tensor<?xindex>) -> index {
|
||||
// CHECK: shape.get_extent
|
||||
%c3 = constant 3 : index
|
||||
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>, index -> index
|
||||
return %0 : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Basic folding.
|
||||
// CHECK-LABEL: func @basic
|
||||
func @basic() -> !shape.size {
|
||||
// CHECK: shape.const_size 2
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||
%0 = shape.const_shape [0, 1, 2] : !shape.shape
|
||||
%c2 = shape.const_size 2
|
||||
%1 = shape.get_extent %0, %c2 : tensor<?xindex>
|
||||
%1 = shape.get_extent %0, %c2 : !shape.shape, !shape.size -> !shape.size
|
||||
return %1 : !shape.size
|
||||
}
|
||||
|
||||
|
@ -252,9 +288,9 @@ func @basic() -> !shape.size {
|
|||
func @out_of_bounds() -> !shape.size {
|
||||
// CHECK: shape.const_shape
|
||||
// CHECK: shape.get_extent
|
||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||
%0 = shape.const_shape [0, 1, 2] : !shape.shape
|
||||
%c3 = shape.const_size 3
|
||||
%1 = shape.get_extent %0, %c3 : tensor<?xindex>
|
||||
%1 = shape.get_extent %0, %c3 : !shape.shape, !shape.size -> !shape.size
|
||||
return %1 : !shape.size
|
||||
}
|
||||
|
||||
|
@ -262,14 +298,13 @@ func @out_of_bounds() -> !shape.size {
|
|||
|
||||
// Should not fold.
|
||||
// CHECK-LABEL: func @not_const
|
||||
func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
|
||||
func @not_const(%arg0 : !shape.shape) -> !shape.size {
|
||||
// CHECK: shape.get_extent
|
||||
%c3 = shape.const_size 3
|
||||
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
|
||||
%0 = shape.get_extent %arg0, %c3 : !shape.shape, !shape.size -> !shape.size
|
||||
return %0 : !shape.size
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
// cstr_eq with non-constant but known equal shapes can be removed.
|
||||
// CHECK-LABEL: func @f
|
||||
|
|
|
@ -102,3 +102,21 @@ func @rank(%arg : !shape.shape) {
|
|||
%0 = shape.rank %arg : !shape.shape -> index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
|
||||
%c0 = constant 0 : index
|
||||
// expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
|
||||
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
|
||||
%c0 = shape.const_size 0
|
||||
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
|
||||
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
|
||||
return %result : index
|
||||
}
|
||||
|
||||
|
|
|
@ -163,13 +163,20 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
|
|||
|
||||
func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
|
||||
%c0 = shape.const_size 0
|
||||
%result = shape.get_extent %arg, %c0 : !shape.shape
|
||||
%result = shape.get_extent %arg, %c0 :
|
||||
!shape.shape, !shape.size -> !shape.size
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
|
||||
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
|
||||
%c0 = constant 0 : index
|
||||
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
|
||||
return %result : index
|
||||
}
|
||||
|
||||
func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
|
||||
%c0 = shape.const_size 0
|
||||
%result = shape.get_extent %arg, %c0 : tensor<?xindex>
|
||||
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue