forked from OSchip/llvm-project
[mlir] Fix some edge cases around 0-element TensorFromElementsOp
This introduces a builder for the more general case that supports zero elements (where the element type can't be inferred from the ValueRange, since it might be empty). Also, fix up some cases in ShapeToStandard lowering that hit this. It happens very easily when dealing with shapes of 0-D tensors. The SameOperandsAndResultElementType is redundant with the new TypesMatchWith and prevented having zero elements. Differential Revision: https://reviews.llvm.org/D87492
This commit is contained in:
parent
aeb4314391
commit
84a6da67e6
|
@ -1613,7 +1613,6 @@ def ExtractElementOp : Std_Op<"extract_element",
|
|||
|
||||
def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
|
||||
NoSideEffect,
|
||||
SameOperandsAndResultElementType,
|
||||
TypesMatchWith<"operand types match result element type",
|
||||
"result", "elements", "SmallVector<Type, 2>("
|
||||
"$_self.cast<ShapedType>().getDimSize(0), "
|
||||
|
@ -1638,7 +1637,11 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
|
|||
// This op is fully verified by its traits.
|
||||
let verifier = ?;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &b, OperationState &result, Type elementType,"
|
||||
"ValueRange elements">,
|
||||
// Special case builder for when `elements` has size >=1.
|
||||
OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
|
||||
];
|
||||
|
||||
|
|
|
@ -182,8 +182,9 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
|
|||
extentOperands.push_back(
|
||||
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
|
||||
}
|
||||
Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Value tensor =
|
||||
rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
|
||||
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
|
||||
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
|
||||
return success();
|
||||
|
@ -444,8 +445,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
|
|||
}
|
||||
|
||||
// Materialize extent tensor.
|
||||
Value staticExtentTensor =
|
||||
rewriter.create<TensorFromElementsOp>(loc, extentValues);
|
||||
Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
|
||||
loc, rewriter.getIndexType(), extentValues);
|
||||
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
|
||||
op.getType());
|
||||
return success();
|
||||
|
|
|
@ -1756,12 +1756,18 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
|
|||
// TensorFromElementsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
|
||||
Type elementType, ValueRange elements) {
|
||||
Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
|
||||
elementType);
|
||||
result.addOperands(elements);
|
||||
result.addTypes(resultTy);
|
||||
}
|
||||
|
||||
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
|
||||
ValueRange elements) {
|
||||
assert(!elements.empty() && "expected at least one element");
|
||||
Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
|
||||
elements.front().getType());
|
||||
build(builder, result, resultTy, elements);
|
||||
build(builder, result, elements.front().getType(), elements);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
|
|
@ -103,6 +103,19 @@ func @const_shape() -> tensor<?xindex> {
|
|||
|
||||
// -----
|
||||
|
||||
// Lower `const_shape` in the case of rank 0.
|
||||
// CHECK-LABEL: func @const_shape_zero_elements
|
||||
// CHECK-SAME: () -> tensor<?xindex>
|
||||
func @const_shape_zero_elements() -> tensor<?xindex> {
|
||||
// CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
|
||||
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
|
||||
// CHECK: return %[[RESULT]] : tensor<?xindex>
|
||||
%shape = shape.const_shape [] : tensor<?xindex>
|
||||
return %shape : tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Lower `any` to its first operand.
|
||||
// CHECK-LABEL: @any_of_three
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
||||
|
@ -227,6 +240,17 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// Lower `shape_of` for 0-D tensor.
|
||||
// CHECK-LABEL: @shape_of_zero_d
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
|
||||
func @shape_of_zero_d(%arg : tensor<f32>) {
|
||||
// CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex>
|
||||
%shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Lower `shape_of` for dynamically shaped tensor.
|
||||
// CHECK-LABEL: @shape_of_dyn
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
|
||||
|
|
|
@ -673,6 +673,9 @@ func @tensor_from_elements() {
|
|||
// CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
|
||||
%2 = tensor_from_elements %c0_f32 : tensor<1xf32>
|
||||
|
||||
// CHECK: tensor_from_elements : tensor<0xindex>
|
||||
%3 = tensor_from_elements : tensor<0xindex>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue