[mlir,shape] Add max/min folder for simple case

When both arguments are the same for these ops, propagate this argument.
This commit is contained in:
Jacques Pienaar 2021-04-06 20:22:42 -07:00
parent 162c2759b6
commit 8b109bc2ea
3 changed files with 46 additions and 0 deletions

View File

@ -416,6 +416,8 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
let hasFolder = 1;
}
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
@ -433,6 +435,8 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
let hasFolder = 1;
}
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {

View File

@ -937,6 +937,28 @@ void NumElementsOp::build(OpBuilder &builder, OperationState &result,
return build(builder, result, type, shape);
}
//===----------------------------------------------------------------------===//
// MaxOp
//===----------------------------------------------------------------------===//
OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// If operands are equal, just propagate one.
if (lhs() == rhs())
return lhs();
return nullptr;
}
//===----------------------------------------------------------------------===//
// MinOp
//===----------------------------------------------------------------------===//
OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
// If operands are equal, just propagate one.
if (lhs() == rhs())
return lhs();
return nullptr;
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

View File

@ -1188,3 +1188,23 @@ func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
return %1 : tensor<3xindex>
}
// ----
// CHECK-LABEL: max_same_arg
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
func @max_same_arg(%a: !shape.shape) -> !shape.shape {
%1 = shape.max %a, %a : !shape.shape, !shape.shape -> !shape.shape
// CHECK: return %[[SHAPE]]
return %1 : !shape.shape
}
// ----
// CHECK-LABEL: min_same_arg
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
func @min_same_arg(%a: !shape.shape) -> !shape.shape {
%1 = shape.min %a, %a : !shape.shape, !shape.shape -> !shape.shape
// CHECK: return %[[SHAPE]]
return %1 : !shape.shape
}