[mlir][shape] Add dim op

Convenience op that allows for simple expression of common crossing of
value/shape divide.

Differential Revision: https://reviews.llvm.org/D131497
This commit is contained in:
Jacques Pienaar 2022-08-12 11:02:07 -07:00
parent 42ee0d8c16
commit 2f025e0e78
5 changed files with 128 additions and 0 deletions

View File

@ -328,6 +328,41 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
let hasFolder = 1;
}
def Shape_DimOp : Shape_Op<"dim",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the specified extent from the shape of a shaped input";
let description = [{
Gets the extent indexed by `dim` from the shape of the `value` operand. If
the dim is error or out-of-bound then it returns an invalid size if the
return type carries error information else the behavior is undefined.
This is a convenience op that performs the equivalent of getting the extent
of a shape (e.g., `dim(x, i) == get_extent(shape_of(x), i)`).
}];
let arguments = (ins AnyShaped:$value,
Shape_SizeOrIndexType:$dim);
let results = (outs Shape_SizeOrIndexType:$extent);
let assemblyFormat = "$value `,` $dim attr-dict `:` type($value) `,` type($dim) `->` "
"type($extent)";
let builders = [
// Builder that allows passing a constant dimension as a simple integer.
OpBuilder<(ins "Value":$value, "int64_t":$dim)>
];
let extraClassDeclaration = [{
/// Get the `dim` value as integer if it is constant.
Optional<int64_t> getConstantDim();
/// Returns when two result types are compatible for this op; method used by
/// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
let hasVerifier = 1;
}
def Shape_GetExtentOp : Shape_Op<"get_extent",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the specified extent from a shape or extent tensor";

View File

@ -322,6 +322,28 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
return success();
}
namespace {
class DimOpConverter : public OpConversionPattern<DimOp> {
using OpConversionPattern<DimOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further
// lowerings. This can be further optimized if needed to avoid intermediate
// steps.
auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
op.getDim());
return success();
}
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@ -693,6 +715,7 @@ void mlir::populateShapeToStandardConversionPatterns(
BroadcastOpConverter,
ConstShapeOpConverter,
ConstSizeOpConversion,
DimOpConverter,
IsBroadcastableOpConverter,
GetExtentOpConverter,
RankOpConverter,

View File

@ -1064,6 +1064,58 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
return operands[0];
}
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
Optional<int64_t> DimOp::getConstantDim() {
if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
return constSizeOp.getValue().getLimitedValue();
if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
return constantOp.getValue().cast<IntegerAttr>().getInt();
return llvm::None;
}
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
Type valType = getValue().getType();
auto valShapedType = valType.dyn_cast<ShapedType>();
if (!valShapedType || !valShapedType.hasRank())
return nullptr;
Optional<int64_t> dim = getConstantDim();
if (!dim.has_value())
return nullptr;
if (dim.value() >= valShapedType.getRank())
return nullptr;
auto extent = valShapedType.getDimSize(*dim);
if (ShapedType::isDynamic(extent))
return nullptr;
return IntegerAttr::get(IndexType::get(getContext()), extent);
}
LogicalResult mlir::shape::DimOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
DimOpAdaptor dimOp(operands);
inferredReturnTypes.assign({dimOp.getDim().getType()});
return success();
}
bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
LogicalResult mlir::shape::DimOp::verify() {
auto st = getValue().getType().cast<ShapedType>();
if (!st.hasRank())
return success();
if (auto dim = getConstantDim()) {
if (*dim < 0 || *dim >= st.getRank())
return emitOpError("index is out of range");
}
return success();
}
//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//

View File

@ -60,6 +60,18 @@ func.func @rank(%shape : !shape.shape) {
// -----
// Express `shape.dim` as `tensor.dim` when valid.
// CHECK-LABEL: @dim
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func.func @dim(%arg : tensor<2x3xf32>, %idx : index) -> index {
// CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
// CHECK: return %[[RESULT]] : index
%result = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
return %result : index
}
// -----
// Express `get_extent` as `tensor.dim` when it relies directly on the outcome of a
// `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of

View File

@ -216,6 +216,12 @@ func.func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
return %result : index
}
func.func @get_dim(%arg : memref<?x?xindex>) -> index {
%c0 = arith.constant 0 : index
%result = shape.dim %arg, %c0 : memref<?x?xindex>, index -> index
return %result : index
}
func.func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
%c0 = shape.const_size 0
%result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size