forked from OSchip/llvm-project
[MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements`
Differential Revision: https://reviews.llvm.org/D82848
This commit is contained in:
parent
a4edc04693
commit
dfcc09890a
|
@ -103,6 +103,39 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
|
||||
public:
|
||||
using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConstShapeOpConverter::matchAndRewrite(
|
||||
ConstShapeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// For now, this lowering supports only extent tensors, not `shape.shape`
|
||||
// types.
|
||||
if (op.getType().isa<ShapeType>())
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
SmallVector<Value, 4> extentOperands;
|
||||
for (auto extent : op.shape()) {
|
||||
extentOperands.push_back(
|
||||
rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
|
||||
}
|
||||
Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
|
||||
Type indexTy = rewriter.getIndexType();
|
||||
Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
|
||||
rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
|
||||
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
|
||||
|
@ -209,6 +242,7 @@ void mlir::populateShapeToStandardConversionPatterns(
|
|||
patterns.insert<
|
||||
AnyOpConversion,
|
||||
BinaryOpConversion<AddOp, AddIOp>,
|
||||
ConstShapeOpConverter,
|
||||
BinaryOpConversion<MulOp, MulIOp>,
|
||||
GetExtentOpConverter,
|
||||
RankOpConverter,
|
||||
|
|
|
@ -111,6 +111,22 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
|
|||
|
||||
// -----
|
||||
|
||||
// Lower `const_shape` to `tensor_from_elements`.
|
||||
// CHECK-LABEL: @const_shape
|
||||
// CHECK-SAME: () -> tensor<?xindex>
|
||||
func @const_shape() -> tensor<?xindex> {
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[C3:.*]] = constant 3 : index
|
||||
// CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]])
|
||||
// CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
|
||||
// CHECK: return %[[RESULT]] : tensor<?xindex>
|
||||
%shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
|
||||
return %shape : tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Lower `any` to its first operand.
|
||||
// CHECK-LABEL: @any_of_three
|
||||
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
|
||||
|
|
Loading…
Reference in New Issue