[shape] Add inferReturnTypes to a couple ops.

- ShapeOfOp
- BroadcastOp

Differential Revision: https://reviews.llvm.org/D78822
This commit is contained in:
Sean Silva 2020-04-24 15:54:22 -07:00
parent 5fff169daa
commit 57a7cd7a13
2 changed files with 20 additions and 2 deletions

View File

@ -130,7 +130,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
let results = (outs Shape_SizeType:$result);
}
def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
def Shape_BroadcastOp : Shape_Op<"broadcast",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the broadcasted output shape of two inputs";
let description = [{
Computes the broadcasted output shape following:
@ -317,7 +318,8 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
let regions = (region SizedRegion<1>:$body);
}
def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
def Shape_ShapeOfOp : Shape_Op<"shape_of",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns shape of a value or shaped type operand";
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);

View File

@ -92,6 +92,14 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
// BroadcastOp
//===----------------------------------------------------------------------===//
LogicalResult BroadcastOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(ShapeType::get(context));
return success();
}
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
if (!operands[0] || !operands[1])
return nullptr;
@ -175,6 +183,14 @@ LogicalResult ConstSizeOp::inferReturnTypes(
// ShapeOfOp
//===----------------------------------------------------------------------===//
LogicalResult ShapeOfOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(ShapeType::get(context));
return success();
}
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
auto type = getOperand().getType().dyn_cast<ShapedType>();
if (!type || !type.hasStaticShape())