forked from OSchip/llvm-project
[mlir] Fix some edge cases around 0-element TensorFromElementsOp
This introduces a builder for the more general case that supports zero elements (where the element type can't be inferred from the ValueRange, since it might be empty). Also, fix up some cases in ShapeToStandard lowering that hit this. It happens very easily when dealing with shapes of 0-D tensors. The SameOperandsAndResultElementType is redundant with the new TypesMatchWith and prevented having zero elements. Differential Revision: https://reviews.llvm.org/D87492
This commit is contained in:
parent
aeb4314391
commit
84a6da67e6
|
@ -1613,7 +1613,6 @@ def ExtractElementOp : Std_Op<"extract_element",
|
||||||
|
|
||||||
def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
|
def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
|
||||||
NoSideEffect,
|
NoSideEffect,
|
||||||
SameOperandsAndResultElementType,
|
|
||||||
TypesMatchWith<"operand types match result element type",
|
TypesMatchWith<"operand types match result element type",
|
||||||
"result", "elements", "SmallVector<Type, 2>("
|
"result", "elements", "SmallVector<Type, 2>("
|
||||||
"$_self.cast<ShapedType>().getDimSize(0), "
|
"$_self.cast<ShapedType>().getDimSize(0), "
|
||||||
|
@ -1638,7 +1637,11 @@ def TensorFromElementsOp : Std_Op<"tensor_from_elements", [
|
||||||
// This op is fully verified by its traits.
|
// This op is fully verified by its traits.
|
||||||
let verifier = ?;
|
let verifier = ?;
|
||||||
|
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
let builders = [
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &b, OperationState &result, Type elementType,"
|
||||||
|
"ValueRange elements">,
|
||||||
|
// Special case builder for when `elements` has size >=1.
|
||||||
OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
|
OpBuilder<"OpBuilder &b, OperationState &result, ValueRange elements">
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|
|
@ -182,8 +182,9 @@ LogicalResult ConstShapeOpConverter::matchAndRewrite(
|
||||||
extentOperands.push_back(
|
extentOperands.push_back(
|
||||||
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
|
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
|
||||||
}
|
}
|
||||||
Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
|
|
||||||
Type indexTy = rewriter.getIndexType();
|
Type indexTy = rewriter.getIndexType();
|
||||||
|
Value tensor =
|
||||||
|
rewriter.create<TensorFromElementsOp>(loc, indexTy, extentOperands);
|
||||||
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
|
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
|
||||||
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
|
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
|
||||||
return success();
|
return success();
|
||||||
|
@ -444,8 +445,8 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Materialize extent tensor.
|
// Materialize extent tensor.
|
||||||
Value staticExtentTensor =
|
Value staticExtentTensor = rewriter.create<TensorFromElementsOp>(
|
||||||
rewriter.create<TensorFromElementsOp>(loc, extentValues);
|
loc, rewriter.getIndexType(), extentValues);
|
||||||
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
|
rewriter.replaceOpWithNewOp<TensorCastOp>(op, staticExtentTensor,
|
||||||
op.getType());
|
op.getType());
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -1756,12 +1756,18 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
|
||||||
// TensorFromElementsOp
|
// TensorFromElementsOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
|
||||||
|
Type elementType, ValueRange elements) {
|
||||||
|
Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
|
||||||
|
elementType);
|
||||||
|
result.addOperands(elements);
|
||||||
|
result.addTypes(resultTy);
|
||||||
|
}
|
||||||
|
|
||||||
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
|
void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result,
|
||||||
ValueRange elements) {
|
ValueRange elements) {
|
||||||
assert(!elements.empty() && "expected at least one element");
|
assert(!elements.empty() && "expected at least one element");
|
||||||
Type resultTy = RankedTensorType::get({static_cast<int64_t>(elements.size())},
|
build(builder, result, elements.front().getType(), elements);
|
||||||
elements.front().getType());
|
|
||||||
build(builder, result, resultTy, elements);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -103,6 +103,19 @@ func @const_shape() -> tensor<?xindex> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Lower `const_shape` in the case of rank 0.
|
||||||
|
// CHECK-LABEL: func @const_shape_zero_elements
|
||||||
|
// CHECK-SAME: () -> tensor<?xindex>
|
||||||
|
func @const_shape_zero_elements() -> tensor<?xindex> {
|
||||||
|
// CHECK: %[[TENSOR:.*]] = tensor_from_elements : tensor<0xindex>
|
||||||
|
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR]] : tensor<0xindex> to tensor<?xindex>
|
||||||
|
// CHECK: return %[[RESULT]] : tensor<?xindex>
|
||||||
|
%shape = shape.const_shape [] : tensor<?xindex>
|
||||||
|
return %shape : tensor<?xindex>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// Lower `any` to its first operand.
|
// Lower `any` to its first operand.
|
||||||
// CHECK-LABEL: @any_of_three
|
// CHECK-LABEL: @any_of_three
|
||||||
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
||||||
|
@ -227,6 +240,17 @@ func @shape_of_stat(%arg : tensor<1x2x3xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Lower `shape_of` for 0-D tensor.
|
||||||
|
// CHECK-LABEL: @shape_of_zero_d
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
|
||||||
|
func @shape_of_zero_d(%arg : tensor<f32>) {
|
||||||
|
// CHECK-DAG: %[[SHAPE_UNCASTED:.*]] = tensor_from_elements : tensor<0xindex>
|
||||||
|
%shape = shape.shape_of %arg : tensor<f32> -> tensor<?xindex>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// Lower `shape_of` for dynamically shaped tensor.
|
// Lower `shape_of` for dynamically shaped tensor.
|
||||||
// CHECK-LABEL: @shape_of_dyn
|
// CHECK-LABEL: @shape_of_dyn
|
||||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>)
|
||||||
|
|
|
@ -673,6 +673,9 @@ func @tensor_from_elements() {
|
||||||
// CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
|
// CHECK: %2 = tensor_from_elements [[C0_F32]] : tensor<1xf32>
|
||||||
%2 = tensor_from_elements %c0_f32 : tensor<1xf32>
|
%2 = tensor_from_elements %c0_f32 : tensor<1xf32>
|
||||||
|
|
||||||
|
// CHECK: tensor_from_elements : tensor<0xindex>
|
||||||
|
%3 = tensor_from_elements : tensor<0xindex>
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue