[mlir] Fix some edge cases around 0-element TensorFromElementsOp

This introduces a builder for the more general case that supports zero
elements (where the element type can't be inferred from the ValueRange,
since it might be empty).
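
A rough sketch of how the two builders are meant to be used (the helper name
and the surrounding OpBuilder/Location are illustrative only, not part of the
patch):

    // Sketch only: assumes an OpBuilder `b` and a Location `loc` are in
    // scope, e.g. inside a conversion pattern.
    static Value makeExtentTensor(OpBuilder &b, Location loc,
                                  ValueRange extents) {
      // General builder: the element type is passed explicitly, so `extents`
      // may be empty (yielding tensor<0xindex>).
      return b.create<TensorFromElementsOp>(loc, b.getIndexType(), extents);
    }

    // The convenience builder is still available, but needs at least one
    // element to infer the element type from:
    //   b.create<TensorFromElementsOp>(loc, nonEmptyExtents);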

Also, fix up some cases in the ShapeToStandard lowering that hit this. It
comes up very easily when dealing with the shapes of 0-D tensors.
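
For illustration (simplified from the ShapeOfOp conversion in this patch; the
real pattern also handles dynamic dimensions, and the names `operand`,
`rewriter`, `loc` are placeholders): a rank-0 operand contributes no extents
at all, so there is nothing to infer the element type from.

    // Sketch only: rank-0 operands produce an empty extent list.
    auto rankedTy = operand.getType().cast<RankedTensorType>(); // tensor<f32>
    SmallVector<Value, 4> extentValues;
    for (int64_t i = 0, rank = rankedTy.getRank(); i != rank; ++i) // rank == 0
      extentValues.push_back(
          rewriter.create<ConstantIndexOp>(loc, rankedTy.getDimSize(i)));
    // extentValues is empty here, so the element type has to be supplied
    // explicitly to get a tensor<0xindex>.
    Value extentTensor = rewriter.create<TensorFromElementsOp>(
        loc, rewriter.getIndexType(), extentValues);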

The SameOperandsAndResultElementType trait is redundant with the new
TypesMatchWith constraint and prevented having zero elements, so it is
removed.
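
Roughly, the TypesMatchWith constraint in the diff below amounts to the check
sketched here (a paraphrase, not the code ODS actually generates); it
degenerates to comparing two empty lists when the result has zero elements:

    // Sketch: build the expected operand types from the result type and
    // compare them against the actual operand types.
    static bool operandTypesMatchResult(TensorFromElementsOp op) {
      auto resultTy = op.getType().cast<ShapedType>();
      SmallVector<Type, 2> expected(resultTy.getDimSize(0),
                                    resultTy.getElementType());
      // For tensor<0xT> this is the empty list, so zero operands verify.
      return llvm::equal(expected, op.getOperandTypes());
    }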

Differential Revision: https://reviews.llvm.org/D87492
Sean Silva 2020-09-10 22:04:58 -07:00
parent aeb4314391
commit 84a6da67e6
5 changed files with 44 additions and 7 deletions


@@ -1613,7 +1613,6 @@ def ExtractElementOp : Std_Op<"extract_element",
def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
NoSideEffect,
SameOperandsAndResultElementType,
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type, 2>("
"$_self.cast<ShapedType>().getDimSize(0), "
@@ -1638,7 +1637,11 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
// This op is fully verified by its traits.
let verifier = ?;
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &b, OperationState &result, Type elementType,"
"ValueRange elements">,
// Special case builder for when `elements` has size >=1.
OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
];


@@ -182,8 +182,9 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
extentOperands.push_back(
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
}
Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
Type indexTy = rewriter.getIndexType();
Value tensor =
rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
return success();
@@ -444,8 +445,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
}
// Materialize extent tensor.
Value staticExtentTensor =
rewriter.create<TensorFromElementsOp>(loc, extentValues);
Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
loc, rewriter.getIndexType(), extentValues);
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
op.getType());
return success();


@@ -1756,12 +1756,18 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
// TensorFromElementsOp
//===----------------------------------------------------------------------===//
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
Type elementType, ValueRange elements) {
Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
elementType);
result.addOperands(elements);
result.addTypes(resultTy);
}
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
ValueRange elements) {
assert(!elements.empty() && "expected at least one element");
Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
elements.front().getType());
build(builder, result, resultTy, elements);
build(builder, result, elements.front().getType(), elements);
}
namespace {


@@ -103,6 +103,19 @@ func @const_shape() -> tensor<?xindex> {
// -----
// Lower `const_shape` in the case of rank 0.
// CHECK-LABEL: func @const_shape_zero_elements
// CHECK-SAME: () -> tensor<?xindex>
func @const_shape_zero_elements() -> tensor<?xindex> {
// CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
// CHECK: return %[[RESULT]] : tensor<?xindex>
%shape = shape.const_shape [] : tensor<?xindex>
return %shape : tensor<?xindex>
}
// -----
// Lower `any` to its first operand.
// CHECK-LABEL: @any_of_three
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
@@ -227,6 +240,17 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
// -----
// Lower `shape_of` for 0-D tensor.
// CHECK-LABEL: @shape_of_zero_d
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @shape_of_zero_d(%arg : tensor<f32>) {
// CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex>
%shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
return
}
// -----
// Lower `shape_of` for dynamically shaped tensor.
// CHECK-LABEL: @shape_of_dyn
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)


@@ -673,6 +673,9 @@ func @tensor_from_elements() {
// CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
%2 = tensor_from_elements %c0_f32 : tensor<1xf32>
// CHECK: tensor_from_elements : tensor<0xindex>
%3 = tensor_from_elements : tensor<0xindex>
return
}