diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index f239d1cfb4f0..b84b6ba3b5d6 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -103,6 +103,39 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
   return success();
 }
 
+namespace {
+class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
+public:
+  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstShapeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult ConstShapeOpConverter::matchAndRewrite(
+    ConstShapeOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+
+  // For now, this lowering supports only extent tensors, not `shape.shape`
+  // types.
+  if (op.getType().isa<ShapeType>())
+    return failure();
+
+  auto loc = op.getLoc();
+  SmallVector<Value, 4> extentOperands;
+  for (auto extent : op.shape()) {
+    extentOperands.push_back(
+        rewriter.create<ConstantIndexOp>(loc, extent.getLimitedValue()));
+  }
+  Value tensor = rewriter.create<TensorFromElementsOp>(loc, extentOperands);
+  Type indexTy = rewriter.getIndexType();
+  Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
+  rewriter.replaceOpWithNewOp<TensorCastOp>(op, tensor, resultTy);
+  return success();
+}
+
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
   using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -209,6 +242,7 @@ void mlir::populateShapeToStandardConversionPatterns(
   patterns.insert<
       AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,
+      ConstShapeOpConverter,
      BinaryOpConversion<MulOp, MulIOp>,
      GetExtentOpConverter,
      RankOpConverter,
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index 9336402d86da..7f875f3bb19f 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -111,6 +111,22 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>, %idx : index)
 
 // -----
 
+// Lower `const_shape` to `tensor_from_elements`.
+// CHECK-LABEL: @const_shape
+// CHECK-SAME: () -> tensor<?xindex>
+func @const_shape() -> tensor<?xindex> {
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[C2:.*]] = constant 2 : index
+  // CHECK: %[[C3:.*]] = constant 3 : index
+  // CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]])
+  // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
+  return %shape : tensor<?xindex>
+}
+
+// -----
+
 // Lower `any` to its first operand.
 // CHECK-LABEL: @any_of_three
 // CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
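
For illustration, the new pattern roughly rewrites the test case above as follows. This is a sketch reconstructed from the CHECK lines; the `-convert-shape-to-std` invocation and the SSA value names are assumptions, not verified pass output.

  // Before:
  func @const_shape() -> tensor<?xindex> {
    %shape = shape.const_shape [1, 2, 3] : tensor<?xindex>
    return %shape : tensor<?xindex>
  }

  // After (approximate), e.g. via `mlir-opt -convert-shape-to-std`:
  // each constant extent becomes an index constant, the constants are
  // packed into a static tensor, and the result is cast to the dynamic
  // extent-tensor type expected by consumers.
  func @const_shape() -> tensor<?xindex> {
    %c1 = constant 1 : index
    %c2 = constant 2 : index
    %c3 = constant 3 : index
    %0 = tensor_from_elements(%c1, %c2, %c3) : tensor<3xindex>
    %1 = tensor_cast %0 : tensor<3xindex> to tensor<?xindex>
    return %1 : tensor<?xindex>
  }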