forked from OSchip/llvm-project
[shape] Add inferReturnTypes to a couple ops.
- ShapeOfOp - BroadcastOp Differential Revision: https://reviews.llvm.org/D78822
This commit is contained in:
parent
5fff169daa
commit
57a7cd7a13
|
@ -130,7 +130,8 @@ def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
|
|||
let results = (outs Shape_SizeType:$result);
|
||||
}
|
||||
|
||||
def Shape_BroadcastOp : Shape_Op<"broadcast", []> {
|
||||
def Shape_BroadcastOp : Shape_Op<"broadcast",
|
||||
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Returns the broadcasted output shape of two inputs";
|
||||
let description = [{
|
||||
Computes the broadcasted output shape following:
|
||||
|
@ -317,7 +318,8 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> {
|
|||
let regions = (region SizedRegion<1>:$body);
|
||||
}
|
||||
|
||||
def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
|
||||
def Shape_ShapeOfOp : Shape_Op<"shape_of",
|
||||
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Returns shape of a value or shaped type operand";
|
||||
|
||||
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
|
||||
|
|
|
@ -92,6 +92,14 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
// BroadcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult BroadcastOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(ShapeType::get(context));
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!operands[0] || !operands[1])
|
||||
return nullptr;
|
||||
|
@ -175,6 +183,14 @@ LogicalResult ConstSizeOp::inferReturnTypes(
|
|||
// ShapeOfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ShapeOfOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
ArrayRef<NamedAttribute> attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(ShapeType::get(context));
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
|
||||
auto type = getOperand().getType().dyn_cast<ShapedType>();
|
||||
if (!type || !type.hasStaticShape())
|
||||
|
|
Loading…
Reference in New Issue