From 57a7cd7a138fed24e109a02dbd8f7d464bf7e177 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Fri, 24 Apr 2020 15:54:22 -0700 Subject: [PATCH] [shape] Add inferReturnTypes to a couple ops. - ShapeOfOp - BroadcastOp Differential Revision: https://reviews.llvm.org/D78822 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 6 ++++-- mlir/lib/Dialect/Shape/IR/Shape.cpp | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index f54456b862fa..fa277f4f89de 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -130,7 +130,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> { let results = (outs Shape_SizeType:$result); } -def Shape_BroadcastOp : Shape_Op<"broadcast", []> { +def Shape_BroadcastOp : Shape_Op<"broadcast", + [DeclareOpInterfaceMethods]> { let summary = "Returns the broadcasted output shape of two inputs"; let description = [{ Computes the broadcasted output shape following: @@ -317,7 +318,8 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> { let regions = (region SizedRegion<1>:$body); } -def Shape_ShapeOfOp : Shape_Op<"shape_of", []> { +def Shape_ShapeOfOp : Shape_Op<"shape_of", + [DeclareOpInterfaceMethods]> { let summary = "Returns shape of a value or shaped type operand"; let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 4a1c0f1d5128..10e766f3cc61 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -92,6 +92,14 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { // BroadcastOp //===----------------------------------------------------------------------===// +LogicalResult BroadcastOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(ShapeType::get(context)); + return success(); +} + OpFoldResult BroadcastOp::fold(ArrayRef operands) { if (!operands[0] || !operands[1]) return nullptr; @@ -175,6 +183,14 @@ LogicalResult ConstSizeOp::inferReturnTypes( // ShapeOfOp //===----------------------------------------------------------------------===// +LogicalResult ShapeOfOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + ArrayRef attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(ShapeType::get(context)); + return success(); +} + OpFoldResult ShapeOfOp::fold(ArrayRef) { auto type = getOperand().getType().dyn_cast(); if (!type || !type.hasStaticShape())