From 41e5dbe0fa95933c60bd70eda65af0f2d0243e39 Mon Sep 17 00:00:00 2001 From: Chia-hung Duan Date: Wed, 18 Aug 2021 20:46:26 +0000 Subject: [PATCH] Enables inferring return types for Shape op if possible Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D102565 --- .../include/mlir/Dialect/Shape/IR/ShapeOps.td | 100 ++++++-- .../mlir/Interfaces/InferTypeOpInterface.td | 4 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 241 ++++++++++++++++-- mlir/test/Dialect/Shape/invalid.mlir | 8 + 4 files changed, 316 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index d415bb8b5622..6b39fbffe9d4 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -28,7 +28,9 @@ include "mlir/IR/SymbolInterfaces.td" class Shape_Op traits = []> : Op; -def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> { +def Shape_AddOp : Shape_Op<"add", + [Commutative, NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "Addition of sizes and indices"; let description = [{ Adds two sizes or indices. If either operand is an error it will be @@ -47,6 +49,12 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> { }]; let verifier = [{ return verifySizeOrIndexOp(*this); }]; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> { @@ -77,6 +85,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> { OptionalAttr:$error); let results = (outs Shape_ShapeOrExtentTensorType:$result); + let builders = [OpBuilder<(ins "Value":$shape)>]; + let assemblyFormat = [{ $shapes attr-dict `:` type($shapes) `->` type($result) }]; @@ -145,7 +155,8 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [ let hasFolder = 1; } -def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> { +def Shape_DivOp : Shape_Op<"div", [NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "Division of sizes and indices"; let description = [{ Divides two sizes or indices. If either operand is an error it will be @@ -173,10 +184,16 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> { let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } -def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative, - InferTypeOpInterface]> { +def Shape_ShapeEqOp : Shape_Op<"shape_eq", + [NoSideEffect, Commutative, InferTypeOpInterface]> { let summary = "Returns whether the input shapes or extent tensors are equal"; let description = [{ Takes one or more shape or extent tensor operands and determines whether @@ -290,7 +307,8 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable", let assemblyFormat = "$shapes attr-dict `:` type($shapes)"; } -def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> { +def Shape_RankOp : Shape_Op<"rank", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Gets the rank of a shape"; let description = [{ Returns the rank of the shape or extent tensor, i.e. the number of extents. @@ -304,6 +322,12 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> { let hasFolder = 1; let hasCanonicalizer = 1; let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { @@ -324,7 +348,8 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { let hasFolder = 1; } -def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> { +def Shape_GetExtentOp : Shape_Op<"get_extent", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Gets the specified extent from a shape or extent tensor"; let description = [{ Gets the extent indexed by `dim` from the `shape` operand. If the shape is @@ -344,6 +369,9 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> { let extraClassDeclaration = [{ /// Get the `dim` value as integer if it is constant. Optional getConstantDim(); + /// Returns when two result types are compatible for this op; method used by + /// InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); }]; let hasFolder = 1; @@ -369,7 +397,8 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> { let hasCanonicalizer = 1; } -def Shape_JoinOp : Shape_Op<"join", [Commutative]> { +def Shape_JoinOp : Shape_Op<"join", + [Commutative, DeclareOpInterfaceMethods]> { let summary = "Returns the least general shape.shape of its operands"; let description = [{ An operation that computes the least general shape of input operands. @@ -405,9 +434,17 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> { $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:` type($arg0) `,` type($arg1) `->` type($result) }]; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } -def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> { +def Shape_MaxOp : Shape_Op<"max", + [Commutative, NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "Elementwise maximum"; let description = [{ Computes the elementwise maximum of two sizes or shapes with equal ranks. @@ -424,9 +461,17 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> { }]; let hasFolder = 1; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } -def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { +def Shape_MinOp : Shape_Op<"min", + [Commutative, NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "Elementwise minimum"; let description = [{ Computes the elementwise minimum of two sizes or shapes with equal ranks. @@ -443,9 +488,17 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { }]; let hasFolder = 1; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } -def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { +def Shape_MulOp : Shape_Op<"mul", + [Commutative, NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "Multiplication of sizes and indices"; let description = [{ Multiplies two sizes or indices. If either operand is an error it will be @@ -465,9 +518,16 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let hasFolder = 1; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } -def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> { +def Shape_NumElementsOp : Shape_Op<"num_elements", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Returns the number of elements for a given shape"; let description = [{ Returns the number of elements for a given shape which is the product of its @@ -480,12 +540,15 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> { let arguments = (ins Shape_ShapeOrExtentTensorType:$shape); let results = (outs Shape_SizeOrIndexType:$result); - let builders = [OpBuilder<(ins "Value":$shape)>]; - let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)"; let hasFolder = 1; let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } def Shape_ReduceOp : Shape_Op<"reduce", @@ -535,7 +598,8 @@ def Shape_ReduceOp : Shape_Op<"reduce", let parser = [{ return ::parse$cppClass(parser, result); }]; } -def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> { +def Shape_ShapeOfOp : Shape_Op<"shape_of", + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Returns shape of a value or shaped type operand"; let description = [{ @@ -548,11 +612,15 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> { let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)"; - let builders = [OpBuilder<(ins "Value":$arg)>]; - let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; let hasCanonicalizer = 1; let hasFolder = 1; + + let extraClassDeclaration = [{ + // Returns when two result types are compatible for this op; method used by + // InferTypeOpInterface + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); + }]; } def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td index fe7c8eeb2e13..1f604e25bf91 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -34,7 +34,9 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> { The method takes an optional location which, if set, will be used to report errors on. The operands and attributes correspond to those with which an Operation would be created (e.g., as used in Operation::create) - and the regions of the op. + and the regions of the op. Be aware that this method is supposed to be + called with valid arguments, e.g., operands are verified, or it may result + in an undefined behavior. }], /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"inferReturnTypes", diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index f75bfc5894b6..7c17455cb3ae 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -89,6 +89,16 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) { return success(); } +template +static bool eachHasOnlyOneOfTypes(TypeRange typeRange) { + return typeRange.size() == 1 && typeRange.front().isa(); +} + +template +static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) { + return eachHasOnlyOneOfTypes(l) && eachHasOnlyOneOfTypes(rs...); +} + //===----------------------------------------------------------------------===// // InlinerInterface //===----------------------------------------------------------------------===// @@ -404,6 +414,27 @@ void AssumingOp::build( result.addTypes(assumingTypes); } +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +LogicalResult mlir::shape::AddOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType().isa() || + operands[1].getType().isa()) + inferredReturnTypes.assign({SizeType::get(context)}); + else + inferredReturnTypes.assign({IndexType::get(context)}); + return success(); +} + +bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + // SizeType is compatible with IndexType. + return eachHasOnlyOneOfTypes(l, r); +} + //===----------------------------------------------------------------------===// // AssumingAllOp //===----------------------------------------------------------------------===// @@ -955,6 +986,23 @@ OpFoldResult DivOp::fold(ArrayRef operands) { return IntegerAttr::get(indexTy, quotient); } +LogicalResult mlir::shape::DivOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType().isa() || + operands[1].getType().isa()) + inferredReturnTypes.assign({SizeType::get(context)}); + else + inferredReturnTypes.assign({IndexType::get(context)}); + return success(); +} + +bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + // SizeType is compatible with IndexType. + return eachHasOnlyOneOfTypes(l, r); +} + //===----------------------------------------------------------------------===// // ShapeEqOp //===----------------------------------------------------------------------===// @@ -1096,6 +1144,20 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, } } +LogicalResult mlir::shape::GetExtentOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.assign({IndexType::get(context)}); + return success(); +} + +bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l, + TypeRange r) { + // SizeType is compatible with IndexType. + return eachHasOnlyOneOfTypes(l, r); +} + //===----------------------------------------------------------------------===// // IsBroadcastableOp //===----------------------------------------------------------------------===// @@ -1114,6 +1176,38 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef operands) { return nullptr; } +//===----------------------------------------------------------------------===// +// JoinOp +//===----------------------------------------------------------------------===// + +LogicalResult mlir::shape::JoinOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.assign({operands[0].getType()}); + return success(); +} + +bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + if (l.size() != 1 || r.size() != 1) + return false; + if (l == r) + return true; + + Type lhs = l.front(); + Type rhs = r.front(); + + if (lhs != rhs) + return false; + + if (lhs.isa() || lhs.isa()) + return true; + + if (succeeded(verifyCompatibleShapes({lhs, rhs}))) + return true; + return false; +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -1173,6 +1267,22 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(context); } +LogicalResult mlir::shape::RankOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType().isa()) + inferredReturnTypes.assign({SizeType::get(context)}); + else + inferredReturnTypes.assign({IndexType::get(context)}); + return success(); +} + +bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + // SizeType is compatible with IndexType. + return eachHasOnlyOneOfTypes(l, r); +} + //===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// @@ -1191,14 +1301,21 @@ OpFoldResult NumElementsOp::fold(ArrayRef operands) { return builder.getIndexAttr(product.getLimitedValue()); } -void NumElementsOp::build(OpBuilder &builder, OperationState &result, - Value shape) { - if (shape.getType().isa()) { - auto type = builder.getIndexType(); - return build(builder, result, type, shape); - } - auto type = SizeType::get(builder.getContext()); - return build(builder, result, type, shape); +LogicalResult mlir::shape::NumElementsOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType().isa()) + inferredReturnTypes.assign({SizeType::get(context)}); + else + inferredReturnTypes.assign({IndexType::get(context)}); + return success(); +} + +bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l, + TypeRange r) { + // SizeType is compatible with IndexType. + return eachHasOnlyOneOfTypes(l, r); } //===----------------------------------------------------------------------===// @@ -1212,6 +1329,27 @@ OpFoldResult MaxOp::fold(llvm::ArrayRef operands) { return nullptr; } +LogicalResult mlir::shape::MaxOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType() == operands[1].getType()) + inferredReturnTypes.assign({operands[0].getType()}); + else + inferredReturnTypes.assign({SizeType::get(context)}); + return success(); +} + +bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + if (l.size() != 1 || r.size() != 1) + return false; + if (l.front().isa() && r.front().isa()) + return true; + if (l.front().isa() && r.front().isa()) + return true; + return false; +} + //===----------------------------------------------------------------------===// // MinOp //===----------------------------------------------------------------------===// @@ -1223,6 +1361,27 @@ OpFoldResult MinOp::fold(llvm::ArrayRef operands) { return nullptr; } +LogicalResult mlir::shape::MinOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType() == operands[1].getType()) + inferredReturnTypes.assign({operands[0].getType()}); + else + inferredReturnTypes.assign({SizeType::get(context)}); + return success(); +} + +bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + if (l.size() != 1 || r.size() != 1) + return false; + if (l.front().isa() && r.front().isa()) + return true; + if (l.front().isa() && r.front().isa()) + return true; + return false; +} + //===----------------------------------------------------------------------===// // MulOp //===----------------------------------------------------------------------===// @@ -1239,6 +1398,22 @@ OpFoldResult MulOp::fold(ArrayRef operands) { return IntegerAttr::get(indexTy, folded); } +LogicalResult mlir::shape::MulOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType().isa() || + operands[1].getType().isa()) + inferredReturnTypes.assign({SizeType::get(context)}); + else + inferredReturnTypes.assign({IndexType::get(context)}); + return success(); +} + +bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + // SizeType is compatible with IndexType. + return eachHasOnlyOneOfTypes(l, r); +} //===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// @@ -1251,18 +1426,6 @@ OpFoldResult ShapeOfOp::fold(ArrayRef) { return builder.getIndexTensorAttr(type.getShape()); } -void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) { - if (auto shapedTy = arg.getType().dyn_cast()) { - int64_t rank = - shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; - Type indexTy = builder.getIndexType(); - Type extentTensorTy = RankedTensorType::get({rank}, indexTy); - return ShapeOfOp::build(builder, result, extentTensorTy, arg); - } - Type shapeTy = builder.getType(); - return ShapeOfOp::build(builder, result, shapeTy, arg); -} - namespace { struct ShapeOfWithTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1317,6 +1480,44 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(context); } +LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType().isa()) + inferredReturnTypes.assign({ShapeType::get(context)}); + else { + auto shapedTy = operands[0].getType().cast(); + int64_t rank = + shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize; + Type indexTy = IndexType::get(context); + Type extentTensorTy = RankedTensorType::get({rank}, indexTy); + inferredReturnTypes.assign({extentTensorTy}); + } + return success(); +} + +bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + if (l.size() != 1 || r.size() != 1) + return false; + if (l == r) + return true; + + Type lhs = l.front(); + Type rhs = r.front(); + + if (!lhs.isa() || !rhs.isa()) + return false; + + if (lhs.isa() || rhs.isa()) + // Shape type is compatible with all other valid return types. + return true; + + if (succeeded(verifyCompatibleShapes({lhs, rhs}))) + return true; + return false; +} + //===----------------------------------------------------------------------===// // SizeToIndexOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir index c605e25b3873..030926a9cce4 100644 --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -97,6 +97,14 @@ func @shape_of(%value_arg : !shape.value_shape, // ----- +func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) { + // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}} + %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32> + return +} + +// ----- + func @rank(%arg : !shape.shape) { // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}} %0 = shape.rank %arg : !shape.shape -> index