diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 47825577921e..074a54f9e5ae 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -132,6 +132,30 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", let assemblyFormat = "attr-dict $value"; } +def Shape_FromExtentsOp : Shape_Op<"from_extents", [ + NoSideEffect, + DeclareOpInterfaceMethods + ]> { + let summary = "Creates a shape from extents"; + let description = [{ + Creates a shape from multiple SSA values representing the extents of + the shape. + + ```mlir + // Rank 2 shape. + %s0 = shape.from_extents %a, %b + // Rank 0 shape. + %s1 = shape.from_extents + ``` + }]; + let arguments = (ins Variadic:$extents); + let results = (outs Shape_ShapeType:$shape); + + let assemblyFormat = "attr-dict $extents"; + + let hasFolder = 1; +} + def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> { let summary = "Creates a shape from a tensor of extents"; let description = [{ diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index a66fa8a8128a..e1d1b3365699 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -201,6 +201,28 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional location, return success(); } +//===----------------------------------------------------------------------===// +// FromExtentsOp +//===----------------------------------------------------------------------===// + +LogicalResult FromExtentsOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(ShapeType::get(context)); + return success(); +} + +OpFoldResult FromExtentsOp::fold(ArrayRef operands) { + if (llvm::any_of(operands, [](Attribute a) { return !a; })) + return nullptr; + SmallVector extents; + for (auto attr : operands) + extents.push_back(attr.cast().getInt()); + Builder builder(getContext()); + return builder.getI64TensorAttr(extents); +} + //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index ee69f90553d9..2e35fc748d86 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -86,3 +86,23 @@ func @f() -> tensor<2xindex> { %0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex> return %0 : tensor<2xindex> } + +// ----- +// Basic case. +// CHECK-LABEL: func @f() +func @f() -> !shape.shape { + // CHECK: shape.const_shape [3, 5, 11] + %e0 = constant 3 : index + %e1 = constant 5 : index + %e2 = constant 11 : index + %ret = shape.from_extents %e0, %e1, %e2 + return %ret : !shape.shape +} + +// CHECK-LABEL: func @no_fold +func @no_fold(%arg0: index) -> !shape.shape { + // CHECK-NOT: shape.const_shape + %e0 = constant 3 : index + %ret = shape.from_extents %e0, %arg0 + return %ret : !shape.shape +}