forked from OSchip/llvm-project
[MLIR][Shape] Allow `get_extent` to operate on extent tensors and indices
Differential Revision: https://reviews.llvm.org/D84435
This commit is contained in:
parent
7f600da828
commit
5984d74139
|
@ -235,9 +235,10 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
|
||||||
an error then it returns an error size.
|
an error then it returns an error size.
|
||||||
}];
|
}];
|
||||||
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
|
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
|
||||||
Shape_SizeType:$dim);
|
Shape_SizeOrIndexType:$dim);
|
||||||
let results = (outs Shape_SizeType:$extent);
|
let results = (outs Shape_SizeOrIndexType:$extent);
|
||||||
let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict";
|
let assemblyFormat = "$shape `,` $dim `:` type($shape) `,` type($dim) `->` "
|
||||||
|
"type($extent) attr-dict";
|
||||||
|
|
||||||
let builders = [
|
let builders = [
|
||||||
// Builder that allows passing a constant dimension as a simple integer.
|
// Builder that allows passing a constant dimension as a simple integer.
|
||||||
|
@ -251,6 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
let verifier = [{ return ::verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
|
def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
|
||||||
|
|
|
@ -535,10 +535,30 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// GetExtentOp
|
// GetExtentOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Optional<int64_t> GetExtentOp::getConstantDim() {
|
static LogicalResult verify(GetExtentOp op) {
|
||||||
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>()) {
|
Type shapeTy = op.shape().getType();
|
||||||
return constSizeOp.value().getLimitedValue();
|
Type dimTy = op.dim().getType();
|
||||||
|
Type extentTy = op.extent().getType();
|
||||||
|
bool errorPropagationPossible =
|
||||||
|
shapeTy.isa<ShapeType>() || dimTy.isa<SizeType>();
|
||||||
|
if (errorPropagationPossible) {
|
||||||
|
if (!extentTy.isa<SizeType>())
|
||||||
|
op.emitError()
|
||||||
|
<< "if at least one of the operands can hold error values then the "
|
||||||
|
"result must be of type `size` to propagate them";
|
||||||
|
} else {
|
||||||
|
if (extentTy.isa<SizeType>())
|
||||||
|
op.emitError() << "if none of the operands can hold error values then "
|
||||||
|
"the result must be of type `index`";
|
||||||
}
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
Optional<int64_t> GetExtentOp::getConstantDim() {
|
||||||
|
if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
|
||||||
|
return constSizeOp.value().getLimitedValue();
|
||||||
|
if (auto constantOp = dim().getDefiningOp<ConstantOp>())
|
||||||
|
return constantOp.value().cast<IntegerAttr>().getInt();
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -558,8 +578,14 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
|
||||||
int64_t dim) {
|
int64_t dim) {
|
||||||
auto loc = result.location;
|
auto loc = result.location;
|
||||||
auto dimAttr = builder.getIndexAttr(dim);
|
auto dimAttr = builder.getIndexAttr(dim);
|
||||||
Value dimValue = builder.create<ConstSizeOp>(loc, dimAttr);
|
if (shape.getType().isa<ShapeType>()) {
|
||||||
build(builder, result, shape, dimValue);
|
Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
|
||||||
|
build(builder, result, builder.getType<SizeType>(), shape, dim);
|
||||||
|
} else {
|
||||||
|
Value dim =
|
||||||
|
builder.create<ConstantOp>(loc, builder.getIndexType(), dimAttr);
|
||||||
|
build(builder, result, builder.getIndexType(), shape, dim);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -136,28 +136,25 @@ func @rank(%shape : tensor<?xindex>) -> index {
|
||||||
// `shape_of` operation.
|
// `shape_of` operation.
|
||||||
// CHECK-LABEL: @get_extent_shape_of
|
// CHECK-LABEL: @get_extent_shape_of
|
||||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
|
||||||
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
|
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : index) -> index {
|
||||||
-> !shape.size {
|
|
||||||
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
|
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
|
||||||
// CHECK: return %[[RESULT]] : index
|
// CHECK: return %[[RESULT]] : index
|
||||||
%shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
|
%shape = shape.shape_of %arg : tensor<2x3xf32> -> tensor<?xindex>
|
||||||
%result = shape.get_extent %shape, %idx : tensor<?xindex>
|
%result = shape.get_extent %shape, %idx : tensor<?xindex>, index -> index
|
||||||
return %result : !shape.size
|
return %result : index
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Express `get_extent` as `std.extract_element` when it relies directly on the
|
// Express `get_extent` as `std.extract_element`.
|
||||||
// outcome of a `from_extent_tensor` operation.
|
|
||||||
// CHECK-LABEL: @get_extent_from_extent_tensor
|
// CHECK-LABEL: @get_extent_from_extent_tensor
|
||||||
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
|
// CHECK-SAME: (%[[EXTENTS:.*]]: tensor<?xindex>, %[[IDX:.*]]: index) -> index
|
||||||
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
|
func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
|
||||||
%idx : !shape.size) -> !shape.size {
|
-> index {
|
||||||
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
|
// CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
|
||||||
// CHECK: return %[[RESULT]] : index
|
// CHECK: return %[[RESULT]] : index
|
||||||
%shape = shape.from_extent_tensor %extents : tensor<?xindex>
|
%result = shape.get_extent %extents, %idx : tensor<?xindex>, index -> index
|
||||||
%result = shape.get_extent %shape, %idx : !shape.shape
|
return %result : index
|
||||||
return %result : !shape.size
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -235,13 +235,49 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Basic folding.
|
||||||
|
// CHECK-LABEL: func @basic
|
||||||
|
func @basic() -> index {
|
||||||
|
// CHECK: constant 2 : index
|
||||||
|
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||||
|
%c2 = constant 2 : index
|
||||||
|
%1 = shape.get_extent %0, %c2 : tensor<?xindex>, index -> index
|
||||||
|
return %1 : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Should not fold.
|
||||||
|
// CHECK-LABEL: func @out_of_bounds
|
||||||
|
func @out_of_bounds() -> index {
|
||||||
|
// CHECK: shape.const_shape
|
||||||
|
// CHECK: shape.get_extent
|
||||||
|
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
||||||
|
%c3 = constant 3 : index
|
||||||
|
%1 = shape.get_extent %0, %c3 : tensor<?xindex>, index -> index
|
||||||
|
return %1 : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Should not fold.
|
||||||
|
// CHECK-LABEL: func @not_const
|
||||||
|
func @not_const(%arg0: tensor<?xindex>) -> index {
|
||||||
|
// CHECK: shape.get_extent
|
||||||
|
%c3 = constant 3 : index
|
||||||
|
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>, index -> index
|
||||||
|
return %0 : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// Basic folding.
|
// Basic folding.
|
||||||
// CHECK-LABEL: func @basic
|
// CHECK-LABEL: func @basic
|
||||||
func @basic() -> !shape.size {
|
func @basic() -> !shape.size {
|
||||||
// CHECK: shape.const_size 2
|
// CHECK: shape.const_size 2
|
||||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
%0 = shape.const_shape [0, 1, 2] : !shape.shape
|
||||||
%c2 = shape.const_size 2
|
%c2 = shape.const_size 2
|
||||||
%1 = shape.get_extent %0, %c2 : tensor<?xindex>
|
%1 = shape.get_extent %0, %c2 : !shape.shape, !shape.size -> !shape.size
|
||||||
return %1 : !shape.size
|
return %1 : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,9 +288,9 @@ func @basic() -> !shape.size {
|
||||||
func @out_of_bounds() -> !shape.size {
|
func @out_of_bounds() -> !shape.size {
|
||||||
// CHECK: shape.const_shape
|
// CHECK: shape.const_shape
|
||||||
// CHECK: shape.get_extent
|
// CHECK: shape.get_extent
|
||||||
%0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
|
%0 = shape.const_shape [0, 1, 2] : !shape.shape
|
||||||
%c3 = shape.const_size 3
|
%c3 = shape.const_size 3
|
||||||
%1 = shape.get_extent %0, %c3 : tensor<?xindex>
|
%1 = shape.get_extent %0, %c3 : !shape.shape, !shape.size -> !shape.size
|
||||||
return %1 : !shape.size
|
return %1 : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,14 +298,13 @@ func @out_of_bounds() -> !shape.size {
|
||||||
|
|
||||||
// Should not fold.
|
// Should not fold.
|
||||||
// CHECK-LABEL: func @not_const
|
// CHECK-LABEL: func @not_const
|
||||||
func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
|
func @not_const(%arg0 : !shape.shape) -> !shape.size {
|
||||||
// CHECK: shape.get_extent
|
// CHECK: shape.get_extent
|
||||||
%c3 = shape.const_size 3
|
%c3 = shape.const_size 3
|
||||||
%0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
|
%0 = shape.get_extent %arg0, %c3 : !shape.shape, !shape.size -> !shape.size
|
||||||
return %0 : !shape.size
|
return %0 : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// cstr_eq with non-constant but known equal shapes can be removed.
|
// cstr_eq with non-constant but known equal shapes can be removed.
|
||||||
// CHECK-LABEL: func @f
|
// CHECK-LABEL: func @f
|
||||||
|
|
|
@ -102,3 +102,21 @@ func @rank(%arg : !shape.shape) {
|
||||||
%0 = shape.rank %arg : !shape.shape -> index
|
%0 = shape.rank %arg : !shape.shape -> index
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
// expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
|
||||||
|
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
|
||||||
|
return %result : !shape.size
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
|
||||||
|
%c0 = shape.const_size 0
|
||||||
|
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
|
||||||
|
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> index
|
||||||
|
return %result : index
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -163,13 +163,20 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
|
||||||
|
|
||||||
func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
|
func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
|
||||||
%c0 = shape.const_size 0
|
%c0 = shape.const_size 0
|
||||||
%result = shape.get_extent %arg, %c0 : !shape.shape
|
%result = shape.get_extent %arg, %c0 :
|
||||||
|
!shape.shape, !shape.size -> !shape.size
|
||||||
return %result : !shape.size
|
return %result : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
|
func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> index
|
||||||
|
return %result : index
|
||||||
|
}
|
||||||
|
|
||||||
|
func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
|
||||||
%c0 = shape.const_size 0
|
%c0 = shape.const_size 0
|
||||||
%result = shape.get_extent %arg, %c0 : tensor<?xindex>
|
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size
|
||||||
return %result : !shape.size
|
return %result : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue