Enables inferring return types for Shape op if possible

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D102565
This commit is contained in:
Chia-hung Duan 2021-08-18 20:46:26 +00:00
parent c22b64ef66
commit 41e5dbe0fa
4 changed files with 316 additions and 37 deletions

View File

@ -28,7 +28,9 @@ include "mlir/IR/SymbolInterfaces.td"
class Shape_Op<string mnemonic, list<OpTrait> traits = []> : class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
Op<ShapeDialect, mnemonic, traits>; Op<ShapeDialect, mnemonic, traits>;
def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> { def Shape_AddOp : Shape_Op<"add",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Addition of sizes and indices"; let summary = "Addition of sizes and indices";
let description = [{ let description = [{
Adds two sizes or indices. If either operand is an error it will be Adds two sizes or indices. If either operand is an error it will be
@ -47,6 +49,12 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
}]; }];
let verifier = [{ return verifySizeOrIndexOp(*this); }]; let verifier = [{ return verifySizeOrIndexOp(*this); }];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> { def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
@ -77,6 +85,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
OptionalAttr<StrAttr>:$error); OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrExtentTensorType:$result); let results = (outs Shape_ShapeOrExtentTensorType:$result);
let builders = [OpBuilder<(ins "Value":$shape)>];
let assemblyFormat = [{ let assemblyFormat = [{
$shapes attr-dict `:` type($shapes) `->` type($result) $shapes attr-dict `:` type($shapes) `->` type($result)
}]; }];
@ -145,7 +155,8 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
let hasFolder = 1; let hasFolder = 1;
} }
def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> { def Shape_DivOp : Shape_Op<"div", [NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Division of sizes and indices"; let summary = "Division of sizes and indices";
let description = [{ let description = [{
Divides two sizes or indices. If either operand is an error it will be Divides two sizes or indices. If either operand is an error it will be
@ -173,10 +184,16 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let hasFolder = 1; let hasFolder = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative, def Shape_ShapeEqOp : Shape_Op<"shape_eq",
InferTypeOpInterface]> { [NoSideEffect, Commutative, InferTypeOpInterface]> {
let summary = "Returns whether the input shapes or extent tensors are equal"; let summary = "Returns whether the input shapes or extent tensors are equal";
let description = [{ let description = [{
Takes one or more shape or extent tensor operands and determines whether Takes one or more shape or extent tensor operands and determines whether
@ -290,7 +307,8 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
let assemblyFormat = "$shapes attr-dict `:` type($shapes)"; let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
} }
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> { def Shape_RankOp : Shape_Op<"rank",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the rank of a shape"; let summary = "Gets the rank of a shape";
let description = [{ let description = [{
Returns the rank of the shape or extent tensor, i.e. the number of extents. Returns the rank of the shape or extent tensor, i.e. the number of extents.
@ -304,6 +322,12 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
let hasFolder = 1; let hasFolder = 1;
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@ -324,7 +348,8 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
let hasFolder = 1; let hasFolder = 1;
} }
def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> { def Shape_GetExtentOp : Shape_Op<"get_extent",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Gets the specified extent from a shape or extent tensor"; let summary = "Gets the specified extent from a shape or extent tensor";
let description = [{ let description = [{
Gets the extent indexed by `dim` from the `shape` operand. If the shape is Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@ -344,6 +369,9 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
let extraClassDeclaration = [{ let extraClassDeclaration = [{
/// Get the `dim` value as integer if it is constant. /// Get the `dim` value as integer if it is constant.
Optional<int64_t> getConstantDim(); Optional<int64_t> getConstantDim();
/// Returns when two result types are compatible for this op; method used by
/// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}]; }];
let hasFolder = 1; let hasFolder = 1;
@ -369,7 +397,8 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
def Shape_JoinOp : Shape_Op<"join", [Commutative]> { def Shape_JoinOp : Shape_Op<"join",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the least general shape.shape of its operands"; let summary = "Returns the least general shape.shape of its operands";
let description = [{ let description = [{
An operation that computes the least general shape of input operands. An operation that computes the least general shape of input operands.
@ -405,9 +434,17 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:` $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result) type($arg0) `,` type($arg1) `->` type($result)
}]; }];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> { def Shape_MaxOp : Shape_Op<"max",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Elementwise maximum"; let summary = "Elementwise maximum";
let description = [{ let description = [{
Computes the elementwise maximum of two sizes or shapes with equal ranks. Computes the elementwise maximum of two sizes or shapes with equal ranks.
@ -424,9 +461,17 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
}]; }];
let hasFolder = 1; let hasFolder = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> { def Shape_MinOp : Shape_Op<"min",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Elementwise minimum"; let summary = "Elementwise minimum";
let description = [{ let description = [{
Computes the elementwise minimum of two sizes or shapes with equal ranks. Computes the elementwise minimum of two sizes or shapes with equal ranks.
@ -443,9 +488,17 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
}]; }];
let hasFolder = 1; let hasFolder = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { def Shape_MulOp : Shape_Op<"mul",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Multiplication of sizes and indices"; let summary = "Multiplication of sizes and indices";
let description = [{ let description = [{
Multiplies two sizes or indices. If either operand is an error it will be Multiplies two sizes or indices. If either operand is an error it will be
@ -465,9 +518,16 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let hasFolder = 1; let hasFolder = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> { def Shape_NumElementsOp : Shape_Op<"num_elements",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the number of elements for a given shape"; let summary = "Returns the number of elements for a given shape";
let description = [{ let description = [{
Returns the number of elements for a given shape which is the product of its Returns the number of elements for a given shape which is the product of its
@ -480,12 +540,15 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape); let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
let results = (outs Shape_SizeOrIndexType:$result); let results = (outs Shape_SizeOrIndexType:$result);
let builders = [OpBuilder<(ins "Value":$shape)>];
let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)"; let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
let hasFolder = 1; let hasFolder = 1;
let verifier = [{ return ::verifySizeOrIndexOp(*this); }]; let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_ReduceOp : Shape_Op<"reduce", def Shape_ReduceOp : Shape_Op<"reduce",
@ -535,7 +598,8 @@ def Shape_ReduceOp : Shape_Op<"reduce",
let parser = [{ return ::parse$cppClass(parser, result); }]; let parser = [{ return ::parse$cppClass(parser, result); }];
} }
def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> { def Shape_ShapeOfOp : Shape_Op<"shape_of",
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns shape of a value or shaped type operand"; let summary = "Returns shape of a value or shaped type operand";
let description = [{ let description = [{
@ -548,11 +612,15 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)"; let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
let builders = [OpBuilder<(ins "Value":$arg)>];
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }]; let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
let hasFolder = 1; let hasFolder = 1;
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
} }
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> { def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {

View File

@ -34,7 +34,9 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
The method takes an optional location which, if set, will be used to The method takes an optional location which, if set, will be used to
report errors on. The operands and attributes correspond to those with report errors on. The operands and attributes correspond to those with
which an Operation would be created (e.g., as used in Operation::create) which an Operation would be created (e.g., as used in Operation::create)
and the regions of the op. and the regions of the op. Be aware that this method is supposed to be
called with valid arguments, e.g., operands are verified, or it may result
in an undefined behavior.
}], }],
/*retTy=*/"::mlir::LogicalResult", /*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"inferReturnTypes", /*methodName=*/"inferReturnTypes",

View File

@ -89,6 +89,16 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
return success(); return success();
} }
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}
template <typename... Ty, typename... ranges>
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// InlinerInterface // InlinerInterface
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -404,6 +414,27 @@ void AssumingOp::build(
result.addTypes(assumingTypes); result.addTypes(assumingTypes);
} }
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::AddOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The sum is a `!shape.size` whenever either summand is one (so errors can
  // propagate); otherwise it is a plain `index`.
  bool eitherIsSize = operands[0].getType().isa<SizeType>() ||
                      operands[1].getType().isa<SizeType>();
  Type resultTy = eitherIsSize ? Type(SizeType::get(context))
                               : Type(IndexType::get(context));
  inferredReturnTypes.assign({resultTy});
  return success();
}

bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // `!shape.size` is compatible with `index`.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AssumingAllOp // AssumingAllOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -955,6 +986,23 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
return IntegerAttr::get(indexTy, quotient); return IntegerAttr::get(indexTy, quotient);
} }
LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The quotient is a `!shape.size` whenever either operand is one (so errors
  // can propagate); otherwise it is a plain `index`.
  bool eitherIsSize = operands[0].getType().isa<SizeType>() ||
                      operands[1].getType().isa<SizeType>();
  inferredReturnTypes.assign({eitherIsSize ? Type(SizeType::get(context))
                                           : Type(IndexType::get(context))});
  return success();
}

bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // `!shape.size` is compatible with `index`.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ShapeEqOp // ShapeEqOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1096,6 +1144,20 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
} }
} }
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Extents are always inferred as plain `index` values.
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
                                                       TypeRange r) {
  // A `!shape.size` result remains compatible with the inferred `index`.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// IsBroadcastableOp // IsBroadcastableOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1114,6 +1176,38 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// JoinOp
//===----------------------------------------------------------------------===//

LogicalResult mlir::shape::JoinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The result type is inferred to be the type of the first operand.
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}

bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Exactly one result type is expected on each side.
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  // NOTE(review): since both ranges have size 1, `l == r` above already
  // returned true whenever the fronts are equal, so `lhs != rhs` is always
  // true here and everything below is unreachable — compatibility degenerates
  // to exact type equality. Confirm whether the size/shape and
  // compatible-shape checks below were intended to run for differing types.
  if (lhs != rhs)
    return false;

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
    return true;
  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// RankOp // RankOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1173,6 +1267,22 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<RankShapeOfCanonicalizationPattern>(context); patterns.add<RankShapeOfCanonicalizationPattern>(context);
} }
LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The rank of an error-carrying `!shape.shape` is a `!shape.size`; the rank
  // of an extent tensor is a plain `index`.
  Type resultTy;
  if (operands[0].getType().isa<ShapeType>())
    resultTy = SizeType::get(context);
  else
    resultTy = IndexType::get(context);
  inferredReturnTypes.assign({resultTy});
  return success();
}

bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // `!shape.size` is compatible with `index`.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// NumElementsOp // NumElementsOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1191,14 +1301,21 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
return builder.getIndexAttr(product.getLimitedValue()); return builder.getIndexAttr(product.getLimitedValue());
} }
void NumElementsOp::build(OpBuilder &builder, OperationState &result, LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
Value shape) { MLIRContext *context, Optional<Location> location, ValueRange operands,
if (shape.getType().isa<ShapedType>()) { DictionaryAttr attributes, RegionRange regions,
auto type = builder.getIndexType(); SmallVectorImpl<Type> &inferredReturnTypes) {
return build(builder, result, type, shape); if (operands[0].getType().isa<ShapeType>())
} inferredReturnTypes.assign({SizeType::get(context)});
auto type = SizeType::get(builder.getContext()); else
return build(builder, result, type, shape); inferredReturnTypes.assign({IndexType::get(context)});
return success();
}
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
TypeRange r) {
// SizeType is compatible with IndexType.
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1212,6 +1329,27 @@ OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
return nullptr; return nullptr;
} }
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // If both operands share a type, the result keeps that type; otherwise fall
  // back to `!shape.size`.
  // NOTE(review): any mismatched pair (e.g. shape vs. extent tensor) is also
  // inferred as `size` here — presumably such mixes are rejected elsewhere by
  // verification; confirm.
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

// Result types are compatible when both sides are shape-typed or both are
// size-typed.
// NOTE(review): equal non-shape/non-size types (e.g. `index` vs. `index`)
// return false here — confirm the interface checks equality before calling
// this hook.
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MinOp // MinOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1223,6 +1361,27 @@ OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
return nullptr; return nullptr;
} }
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // If both operands share a type, the result keeps that type; otherwise fall
  // back to `!shape.size`.
  // NOTE(review): any mismatched pair (e.g. shape vs. extent tensor) is also
  // inferred as `size` here — presumably such mixes are rejected elsewhere by
  // verification; confirm.
  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

// Result types are compatible when both sides are shape-typed or both are
// size-typed.
// NOTE(review): equal non-shape/non-size types (e.g. `index` vs. `index`)
// return false here — confirm the interface checks equality before calling
// this hook.
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
    return true;
  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
    return true;
  return false;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MulOp // MulOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1239,6 +1398,22 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
return IntegerAttr::get(indexTy, folded); return IntegerAttr::get(indexTy, folded);
} }
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // The product is a `!shape.size` whenever either factor is one (so errors
  // can propagate); otherwise it is a plain `index`.
  bool eitherIsSize = operands[0].getType().isa<SizeType>() ||
                      operands[1].getType().isa<SizeType>();
  inferredReturnTypes.assign({eitherIsSize ? Type(SizeType::get(context))
                                           : Type(IndexType::get(context))});
  return success();
}

bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // `!shape.size` is compatible with `index`.
  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ShapeOfOp // ShapeOfOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1251,18 +1426,6 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
return builder.getIndexTensorAttr(type.getShape()); return builder.getIndexTensorAttr(type.getShape());
} }
// Builds a shape_of op, deriving the result type from the argument: a shaped
// argument yields a rank-1 extent tensor whose static extent is the
// argument's rank (dynamic when unranked); any other argument yields the
// opaque `!shape.shape` type.
// NOTE(review): this diff removes the builder; result-type construction moves
// to ShapeOfOp::inferReturnTypes.
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
  if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = builder.getIndexType();
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    return ShapeOfOp::build(builder, result, extentTensorTy, arg);
  }
  Type shapeTy = builder.getType<ShapeType>();
  return ShapeOfOp::build(builder, result, shapeTy, arg);
}
namespace { namespace {
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> { struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
@ -1317,6 +1480,44 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context); patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
} }
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // A `!shape.value_shape` argument yields an opaque `!shape.shape`; any
  // other argument is a shaped type and yields a rank-1 extent tensor whose
  // static extent is the argument's rank when known.
  Type argTy = operands[0].getType();
  if (argTy.isa<ValueShapeType>()) {
    inferredReturnTypes.assign({ShapeType::get(context)});
    return success();
  }
  auto shapedTy = argTy.cast<ShapedType>();
  int64_t rank =
      shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
  inferredReturnTypes.assign(
      {RankedTensorType::get({rank}, IndexType::get(context))});
  return success();
}

bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Exactly one result type is expected on each side.
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;
  Type lhs = l.front();
  Type rhs = r.front();
  // Both sides must be valid shape-like result types at all.
  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
    return false;
  // `!shape.shape` is compatible with every other valid return type.
  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
    return true;
  // Otherwise both are shaped types and must have compatible shapes.
  return succeeded(verifyCompatibleShapes({lhs, rhs}));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// SizeToIndexOp // SizeToIndexOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -97,6 +97,14 @@ func @shape_of(%value_arg : !shape.value_shape,
// ----- // -----
func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
// expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
%0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
return
}
// -----
func @rank(%arg : !shape.shape) { func @rank(%arg : !shape.shape) {
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}} // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
%0 = shape.rank %arg : !shape.shape -> index %0 = shape.rank %arg : !shape.shape -> index