forked from OSchip/llvm-project
Enables inferring return types for Shape op if possible
Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D102565
This commit is contained in:
parent
c22b64ef66
commit
41e5dbe0fa
|
@ -28,7 +28,9 @@ include "mlir/IR/SymbolInterfaces.td"
|
|||
class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<ShapeDialect, mnemonic, traits>;
|
||||
|
||||
def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
|
||||
def Shape_AddOp : Shape_Op<"add",
|
||||
[Commutative, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Addition of sizes and indices";
|
||||
let description = [{
|
||||
Adds two sizes or indices. If either operand is an error it will be
|
||||
|
@ -47,6 +49,12 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
|
|||
}];
|
||||
|
||||
let verifier = [{ return verifySizeOrIndexOp(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
|
||||
|
@ -77,6 +85,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
|
|||
OptionalAttr<StrAttr>:$error);
|
||||
let results = (outs Shape_ShapeOrExtentTensorType:$result);
|
||||
|
||||
let builders = [OpBuilder<(ins "Value":$shape)>];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$shapes attr-dict `:` type($shapes) `->` type($result)
|
||||
}];
|
||||
|
@ -145,7 +155,8 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
|
||||
def Shape_DivOp : Shape_Op<"div", [NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Division of sizes and indices";
|
||||
let description = [{
|
||||
Divides two sizes or indices. If either operand is an error it will be
|
||||
|
@ -173,10 +184,16 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
|
|||
|
||||
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
|
||||
InferTypeOpInterface]> {
|
||||
def Shape_ShapeEqOp : Shape_Op<"shape_eq",
|
||||
[NoSideEffect, Commutative, InferTypeOpInterface]> {
|
||||
let summary = "Returns whether the input shapes or extent tensors are equal";
|
||||
let description = [{
|
||||
Takes one or more shape or extent tensor operands and determines whether
|
||||
|
@ -290,7 +307,8 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
|
|||
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
|
||||
}
|
||||
|
||||
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
|
||||
def Shape_RankOp : Shape_Op<"rank",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Gets the rank of a shape";
|
||||
let description = [{
|
||||
Returns the rank of the shape or extent tensor, i.e. the number of extents.
|
||||
|
@ -304,6 +322,12 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
|
|||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
|
||||
|
@ -324,7 +348,8 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
|
||||
def Shape_GetExtentOp : Shape_Op<"get_extent",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Gets the specified extent from a shape or extent tensor";
|
||||
let description = [{
|
||||
Gets the extent indexed by `dim` from the `shape` operand. If the shape is
|
||||
|
@ -344,6 +369,9 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
|
|||
let extraClassDeclaration = [{
|
||||
/// Get the `dim` value as integer if it is constant.
|
||||
Optional<int64_t> getConstantDim();
|
||||
/// Returns when two result types are compatible for this op; method used by
|
||||
/// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
@ -369,7 +397,8 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
|
||||
def Shape_JoinOp : Shape_Op<"join",
|
||||
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Returns the least general shape.shape of its operands";
|
||||
let description = [{
|
||||
An operation that computes the least general shape of input operands.
|
||||
|
@ -405,9 +434,17 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
|
|||
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
|
||||
type($arg0) `,` type($arg1) `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
|
||||
def Shape_MaxOp : Shape_Op<"max",
|
||||
[Commutative, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Elementwise maximum";
|
||||
let description = [{
|
||||
Computes the elementwise maximum of two sizes or shapes with equal ranks.
|
||||
|
@ -424,9 +461,17 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
||||
def Shape_MinOp : Shape_Op<"min",
|
||||
[Commutative, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Elementwise minimum";
|
||||
let description = [{
|
||||
Computes the elementwise minimum of two sizes or shapes with equal ranks.
|
||||
|
@ -443,9 +488,17 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
|||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
||||
def Shape_MulOp : Shape_Op<"mul",
|
||||
[Commutative, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Multiplication of sizes and indices";
|
||||
let description = [{
|
||||
Multiplies two sizes or indices. If either operand is an error it will be
|
||||
|
@ -465,9 +518,16 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
|||
|
||||
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
|
||||
def Shape_NumElementsOp : Shape_Op<"num_elements",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Returns the number of elements for a given shape";
|
||||
let description = [{
|
||||
Returns the number of elements for a given shape which is the product of its
|
||||
|
@ -480,12 +540,15 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
|
|||
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
|
||||
let results = (outs Shape_SizeOrIndexType:$result);
|
||||
|
||||
let builders = [OpBuilder<(ins "Value":$shape)>];
|
||||
|
||||
let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
|
||||
|
||||
let hasFolder = 1;
|
||||
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_ReduceOp : Shape_Op<"reduce",
|
||||
|
@ -535,7 +598,8 @@ def Shape_ReduceOp : Shape_Op<"reduce",
|
|||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
|
||||
def Shape_ShapeOfOp : Shape_Op<"shape_of",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Returns shape of a value or shaped type operand";
|
||||
|
||||
let description = [{
|
||||
|
@ -548,11 +612,15 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
|
|||
|
||||
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
|
||||
|
||||
let builders = [OpBuilder<(ins "Value":$arg)>];
|
||||
|
||||
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
|
||||
|
|
|
@ -34,7 +34,9 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
|
|||
The method takes an optional location which, if set, will be used to
|
||||
report errors on. The operands and attributes correspond to those with
|
||||
which an Operation would be created (e.g., as used in Operation::create)
|
||||
and the regions of the op.
|
||||
and the regions of the op. Be aware that this method is supposed to be
|
||||
called with valid arguments, e.g., operands are verified, or it may result
|
||||
in an undefined behavior.
|
||||
}],
|
||||
/*retTy=*/"::mlir::LogicalResult",
|
||||
/*methodName=*/"inferReturnTypes",
|
||||
|
|
|
@ -89,6 +89,16 @@ static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename... Ty>
|
||||
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
|
||||
return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
|
||||
}
|
||||
|
||||
template <typename... Ty, typename... ranges>
|
||||
static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
|
||||
return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InlinerInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -404,6 +414,27 @@ void AssumingOp::build(
|
|||
result.addTypes(assumingTypes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult mlir::shape::AddOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType().isa<SizeType>() ||
|
||||
operands[1].getType().isa<SizeType>())
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
// SizeType is compatible with IndexType.
|
||||
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AssumingAllOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -955,6 +986,23 @@ OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
|
|||
return IntegerAttr::get(indexTy, quotient);
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::DivOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType().isa<SizeType>() ||
|
||||
operands[1].getType().isa<SizeType>())
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
// SizeType is compatible with IndexType.
|
||||
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeEqOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1096,6 +1144,20 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
|
|||
}
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
|
||||
TypeRange r) {
|
||||
// SizeType is compatible with IndexType.
|
||||
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IsBroadcastableOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1114,6 +1176,38 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// JoinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult mlir::shape::JoinOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.assign({operands[0].getType()});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
if (l.size() != 1 || r.size() != 1)
|
||||
return false;
|
||||
if (l == r)
|
||||
return true;
|
||||
|
||||
Type lhs = l.front();
|
||||
Type rhs = r.front();
|
||||
|
||||
if (lhs != rhs)
|
||||
return false;
|
||||
|
||||
if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
|
||||
return true;
|
||||
|
||||
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RankOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1173,6 +1267,22 @@ void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
patterns.add<RankShapeOfCanonicalizationPattern>(context);
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::RankOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType().isa<ShapeType>())
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
// SizeType is compatible with IndexType.
|
||||
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NumElementsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1191,14 +1301,21 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
|
|||
return builder.getIndexAttr(product.getLimitedValue());
|
||||
}
|
||||
|
||||
void NumElementsOp::build(OpBuilder &builder, OperationState &result,
|
||||
Value shape) {
|
||||
if (shape.getType().isa<ShapedType>()) {
|
||||
auto type = builder.getIndexType();
|
||||
return build(builder, result, type, shape);
|
||||
}
|
||||
auto type = SizeType::get(builder.getContext());
|
||||
return build(builder, result, type, shape);
|
||||
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType().isa<ShapeType>())
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
|
||||
TypeRange r) {
|
||||
// SizeType is compatible with IndexType.
|
||||
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1212,6 +1329,27 @@ OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType() == operands[1].getType())
|
||||
inferredReturnTypes.assign({operands[0].getType()});
|
||||
else
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
if (l.size() != 1 || r.size() != 1)
|
||||
return false;
|
||||
if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
|
||||
return true;
|
||||
if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1223,6 +1361,27 @@ OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::MinOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType() == operands[1].getType())
|
||||
inferredReturnTypes.assign({operands[0].getType()});
|
||||
else
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
if (l.size() != 1 || r.size() != 1)
|
||||
return false;
|
||||
if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())
|
||||
return true;
|
||||
if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1239,6 +1398,22 @@ OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
|
|||
return IntegerAttr::get(indexTy, folded);
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::MulOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType().isa<SizeType>() ||
|
||||
operands[1].getType().isa<SizeType>())
|
||||
inferredReturnTypes.assign({SizeType::get(context)});
|
||||
else
|
||||
inferredReturnTypes.assign({IndexType::get(context)});
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
// SizeType is compatible with IndexType.
|
||||
return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeOfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1251,18 +1426,6 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
|
|||
return builder.getIndexTensorAttr(type.getShape());
|
||||
}
|
||||
|
||||
void ShapeOfOp::build(OpBuilder &builder, OperationState &result, Value arg) {
|
||||
if (auto shapedTy = arg.getType().dyn_cast<ShapedType>()) {
|
||||
int64_t rank =
|
||||
shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
|
||||
Type indexTy = builder.getIndexType();
|
||||
Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
|
||||
return ShapeOfOp::build(builder, result, extentTensorTy, arg);
|
||||
}
|
||||
Type shapeTy = builder.getType<ShapeType>();
|
||||
return ShapeOfOp::build(builder, result, shapeTy, arg);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
|
||||
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
|
||||
|
@ -1317,6 +1480,44 @@ void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
|
||||
}
|
||||
|
||||
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
if (operands[0].getType().isa<ValueShapeType>())
|
||||
inferredReturnTypes.assign({ShapeType::get(context)});
|
||||
else {
|
||||
auto shapedTy = operands[0].getType().cast<ShapedType>();
|
||||
int64_t rank =
|
||||
shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
|
||||
Type indexTy = IndexType::get(context);
|
||||
Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
|
||||
inferredReturnTypes.assign({extentTensorTy});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
if (l.size() != 1 || r.size() != 1)
|
||||
return false;
|
||||
if (l == r)
|
||||
return true;
|
||||
|
||||
Type lhs = l.front();
|
||||
Type rhs = r.front();
|
||||
|
||||
if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())
|
||||
return false;
|
||||
|
||||
if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())
|
||||
// Shape type is compatible with all other valid return types.
|
||||
return true;
|
||||
|
||||
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SizeToIndexOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -97,6 +97,14 @@ func @shape_of(%value_arg : !shape.value_shape,
|
|||
|
||||
// -----
|
||||
|
||||
func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
|
||||
// expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
|
||||
%0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rank(%arg : !shape.shape) {
|
||||
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
|
||||
%0 = shape.rank %arg : !shape.shape -> index
|
||||
|
|
Loading…
Reference in New Issue