forked from OSchip/llvm-project
[mlir][shape] Add `shape.from_extents`.
Summary: This is a basic op needed for creating shapes from SSA values representing the extents. Differential Revision: https://reviews.llvm.org/D79833
This commit is contained in:
parent
47650dcbee
commit
21b0eff773
|
@ -132,6 +132,30 @@ def Shape_ConstSizeOp : Shape_Op<"const_size",
|
|||
let assemblyFormat = "attr-dict $value";
|
||||
}
|
||||
|
||||
def Shape_FromExtentsOp : Shape_Op<"from_extents", [
|
||||
NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||
]> {
|
||||
let summary = "Creates a shape from extents";
|
||||
let description = [{
|
||||
Creates a shape from multiple SSA values representing the extents of
|
||||
the shape.
|
||||
|
||||
```mlir
|
||||
// Rank 2 shape.
|
||||
%s0 = shape.from_extents %a, %b
|
||||
// Rank 0 shape.
|
||||
%s1 = shape.from_extents
|
||||
```
|
||||
}];
|
||||
let arguments = (ins Variadic<Index>:$extents);
|
||||
let results = (outs Shape_ShapeType:$shape);
|
||||
|
||||
let assemblyFormat = "attr-dict $extents";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
|
||||
let summary = "Creates a shape from a tensor of extents";
|
||||
let description = [{
|
||||
|
|
|
@ -201,6 +201,28 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FromExtentsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult FromExtentsOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(ShapeType::get(context));
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (llvm::any_of(operands, [](Attribute a) { return !a; }))
|
||||
return nullptr;
|
||||
SmallVector<int64_t, 6> extents;
|
||||
for (auto attr : operands)
|
||||
extents.push_back(attr.cast<IntegerAttr>().getInt());
|
||||
Builder builder(getContext());
|
||||
return builder.getI64TensorAttr(extents);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeOfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -86,3 +86,23 @@ func @f() -> tensor<2xindex> {
|
|||
%0 = "shape.to_extent_tensor"(%cs) : (!shape.shape) -> tensor<2xindex>
|
||||
return %0 : tensor<2xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
// Basic case.
|
||||
// CHECK-LABEL: func @f()
|
||||
func @f() -> !shape.shape {
|
||||
// CHECK: shape.const_shape [3, 5, 11]
|
||||
%e0 = constant 3 : index
|
||||
%e1 = constant 5 : index
|
||||
%e2 = constant 11 : index
|
||||
%ret = shape.from_extents %e0, %e1, %e2
|
||||
return %ret : !shape.shape
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @no_fold
|
||||
func @no_fold(%arg0: index) -> !shape.shape {
|
||||
// CHECK-NOT: shape.const_shape
|
||||
%e0 = constant 3 : index
|
||||
%ret = shape.from_extents %e0, %arg0
|
||||
return %ret : !shape.shape
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue