forked from OSchip/llvm-project
[mlir,shape] Add max/min folder for simple case
When both arguments are the same for these ops, propagate this argument.
This commit is contained in:
parent
162c2759b6
commit
8b109bc2ea
|
@ -416,6 +416,8 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
|
|||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
||||
|
@ -433,6 +435,8 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
|||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
||||
|
|
|
@ -937,6 +937,28 @@ void NumElementsOp::build(OpBuilder &builder, OperationState &result,
|
|||
return build(builder, result, type, shape);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MaxOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
// If operands are equal, just propagate one.
|
||||
if (lhs() == rhs())
|
||||
return lhs();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MinOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||
// If operands are equal, just propagate one.
|
||||
if (lhs() == rhs())
|
||||
return lhs();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MulOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1188,3 +1188,23 @@ func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
|
|||
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
||||
return %1 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
// CHECK-LABEL: max_same_arg
|
||||
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
|
||||
func @max_same_arg(%a: !shape.shape) -> !shape.shape {
|
||||
%1 = shape.max %a, %a : !shape.shape, !shape.shape -> !shape.shape
|
||||
// CHECK: return %[[SHAPE]]
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
||||
// ----
|
||||
|
||||
// CHECK-LABEL: min_same_arg
|
||||
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
|
||||
func @min_same_arg(%a: !shape.shape) -> !shape.shape {
|
||||
%1 = shape.min %a, %a : !shape.shape, !shape.shape -> !shape.shape
|
||||
// CHECK: return %[[SHAPE]]
|
||||
return %1 : !shape.shape
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue