[mlir][shape] Add dim op
Convenience op that allows for simple expression of the common crossing of the value/shape divide.

Differential Revision: https://reviews.llvm.org/D131497
parent 42ee0d8c16
commit 2f025e0e78
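For illustration only (not part of the patch), a minimal use of the new op, mirroring the test added below; the function and value names are made up:

func.func @example(%arg : tensor<2x3xf32>, %idx : index) -> index {
  // Query the extent of dimension %idx directly from the shaped value.
  %extent = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
  return %extent : index
}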
@@ -328,6 +328,41 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [
  let hasFolder = 1;
}

def Shape_DimOp : Shape_Op<"dim",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Gets the specified extent from the shape of a shaped input";
  let description = [{
    Gets the extent indexed by `dim` from the shape of the `value` operand. If
    `dim` is an error or out of bounds, then an invalid size is returned if the
    return type carries error information; otherwise the behavior is undefined.

    This is a convenience op that performs the equivalent of getting the extent
    of a shape (e.g., `dim(x, i) == get_extent(shape_of(x), i)`).
  }];
  let arguments = (ins AnyShaped:$value,
                       Shape_SizeOrIndexType:$dim);
  let results = (outs Shape_SizeOrIndexType:$extent);
  let assemblyFormat = "$value `,` $dim attr-dict `:` type($value) `,` type($dim) `->` "
                       "type($extent)";

  let builders = [
    // Builder that allows passing a constant dimension as a simple integer.
    OpBuilder<(ins "Value":$value, "int64_t":$dim)>
  ];

  let extraClassDeclaration = [{
    /// Get the `dim` value as an integer if it is constant.
    Optional<int64_t> getConstantDim();

    /// Returns whether two result types are compatible for this op; method
    /// used by InferTypeOpInterface.
    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
  }];

  let hasFolder = 1;
  let hasVerifier = 1;
}

def Shape_GetExtentOp : Shape_Op<"get_extent",
    [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
  let summary = "Gets the specified extent from a shape or extent tensor";
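As the description above says, `shape.dim` is shorthand for reading an extent off the operand's shape. A sketch of that equivalence in IR, with illustrative names; both results below carry the same extent:

// Direct form added by this patch.
%e1 = shape.dim %x, %i : tensor<?x?xf32>, index -> index

// Expanded form using existing ops.
%shape = shape.shape_of %x : tensor<?x?xf32> -> tensor<2xindex>
%e2 = shape.get_extent %shape, %i : tensor<2xindex>, index -> index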
@@ -322,6 +322,28 @@ LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
  return success();
}

namespace {
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
  // lowerings. This can be further optimized if needed to avoid intermediate
  // steps.
  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getDim());
  return success();
}

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;
@@ -693,6 +715,7 @@ void mlir::populateShapeToStandardConversionPatterns(
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
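For context, a hedged sketch of how a caller might exercise the newly registered DimOpConverter through the dialect-conversion framework; this pass setup is an assumption, not code from the patch:

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Rewrites every shape.dim under `root` via DimOpConverter; in this sketch all
// other ops (including the produced shape.shape_of / shape.get_extent) are
// accepted as-is.
static LogicalResult lowerShapeDim(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  populateShapeToStandardConversionPatterns(patterns);

  ConversionTarget target(*ctx);
  target.addIllegalOp<shape::DimOp>();
  // Treat every other op as legal for this sketch.
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  return applyPartialConversion(root, target, std::move(patterns));
}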
@@ -1064,6 +1064,58 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
  return operands[0];
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

Optional<int64_t> DimOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return llvm::None;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  Type valType = getValue().getType();
  auto valShapedType = valType.dyn_cast<ShapedType>();
  if (!valShapedType || !valShapedType.hasRank())
    return nullptr;
  Optional<int64_t> dim = getConstantDim();
  if (!dim.has_value())
    return nullptr;
  if (dim.value() >= valShapedType.getRank())
    return nullptr;
  auto extent = valShapedType.getDimSize(*dim);
  if (ShapedType::isDynamic(extent))
    return nullptr;
  return IntegerAttr::get(IndexType::get(getContext()), extent);
}

LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  DimOpAdaptor dimOp(operands);
  inferredReturnTypes.assign({dimOp.getDim().getType()});
  return success();
}

bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}

LogicalResult mlir::shape::DimOp::verify() {
  auto st = getValue().getType().cast<ShapedType>();
  if (!st.hasRank())
    return success();
  if (auto dim = getConstantDim()) {
    if (*dim < 0 || *dim >= st.getRank())
      return emitOpError("index is out of range");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DivOp
//===----------------------------------------------------------------------===//
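To illustrate the new folder and verifier (a sketch, not taken from the patch): with a ranked operand and a constant index, the extent folds to a constant, while a constant index outside the rank is rejected at verification time.

// Dimension 1 of tensor<2x3xf32> is statically 3, so with a constant index
// the op folds; after canonicalization %e is effectively the constant 3.
%c1 = arith.constant 1 : index
%e = shape.dim %t, %c1 : tensor<2x3xf32>, index -> index

// Invalid: constant index 7 is out of range for rank 2, so the verifier
// reports "index is out of range".
%c7 = arith.constant 7 : index
%bad = shape.dim %t, %c7 : tensor<2x3xf32>, index -> index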
@@ -60,6 +60,18 @@ func.func @rank(%shape : !shape.shape) {

// -----

// Express `shape.dim` as `tensor.dim` when valid.
// CHECK-LABEL: @dim
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: index) -> index
func.func @dim(%arg : tensor<2x3xf32>, %idx : index) -> index {
  // CHECK: %[[RESULT:.*]] = tensor.dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
  // CHECK: return %[[RESULT]] : index
  %result = shape.dim %arg, %idx : tensor<2x3xf32>, index -> index
  return %result : index
}

// -----

// Express `get_extent` as `tensor.dim` when it relies directly on the outcome
// of a `shape_of` operation.
// CHECK-LABEL: @get_extent_shape_of
@@ -216,6 +216,12 @@ func.func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> index {
  return %result : index
}

func.func @get_dim(%arg : memref<?x?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %result = shape.dim %arg, %c0 : memref<?x?xindex>, index -> index
  return %result : index
}

func.func @get_extent_on_mixed_operands(%arg : tensor<?xindex>) -> !shape.size {
  %c0 = shape.const_size 0
  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, !shape.size -> !shape.size