forked from OSchip/llvm-project
[MLIR][Shape] Lower `shape.get_extent` to `std.dim` when possible
When the shape is derived from a tensor argument the shape extent can be derived directly from that tensor with `std.dim`. This lowering pattern circumvents the necessity to materialize the shape in memory. Differential Revision: https://reviews.llvm.org/D82644
This commit is contained in:
parent
37cc4fa2ea
commit
76d72c941d
|
@ -19,3 +19,10 @@ def SizeToIndexOpConversion : Pat<
|
|||
(Shape_SizeToIndexOp $arg),
|
||||
(replaceWithValue $arg)>;
|
||||
|
||||
// Derive shape extent directly from shape origin if possible.
|
||||
// This circumvents the necessity to materialize the shape in memory.
|
||||
def GetExtentShapeOfConversion : Pat<
|
||||
(Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
|
||||
(Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
|
||||
[],
|
||||
(addBenefit 10)>;
|
||||
|
|
|
@ -127,3 +127,19 @@ func @rank(%shape : !shape.shape) -> !shape.size {
|
|||
%rank = shape.rank %shape
|
||||
return %rank : !shape.size
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Express `get_extent` as `std.dim` when it relies directly on the outcome of a
|
||||
// `shape_of` operation.
|
||||
// CHECK-LABEL: @get_extent_shape_of
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
|
||||
func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
|
||||
-> !shape.size {
|
||||
// CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
|
||||
// CHECK: return %[[RESULT]] : index
|
||||
%shape = shape.shape_of %arg : tensor<2x3xf32>
|
||||
%result = shape.get_extent %shape, %idx
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue