forked from OSchip/llvm-project
[mlir,shape] Add max/min folder for simple case
When both arguments are the same for these ops, propagate this argument.
This commit is contained in:
parent
162c2759b6
commit
8b109bc2ea
|
@ -416,6 +416,8 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
||||||
|
@ -433,6 +435,8 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
||||||
|
|
|
@ -937,6 +937,28 @@ void NumElementsOp::build(OpBuilder &builder, OperationState &result,
|
||||||
return build(builder, result, type, shape);
|
return build(builder, result, type, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MaxOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult MaxOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||||
|
// If operands are equal, just propagate one.
|
||||||
|
if (lhs() == rhs())
|
||||||
|
return lhs();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// MinOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult MinOp::fold(llvm::ArrayRef<mlir::Attribute> operands) {
|
||||||
|
// If operands are equal, just propagate one.
|
||||||
|
if (lhs() == rhs())
|
||||||
|
return lhs();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MulOp
|
// MulOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1188,3 +1188,23 @@ func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
|
||||||
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
||||||
return %1 : tensor<3xindex>
|
return %1 : tensor<3xindex>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----
|
||||||
|
|
||||||
|
// CHECK-LABEL: max_same_arg
|
||||||
|
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
|
||||||
|
func @max_same_arg(%a: !shape.shape) -> !shape.shape {
|
||||||
|
%1 = shape.max %a, %a : !shape.shape, !shape.shape -> !shape.shape
|
||||||
|
// CHECK: return %[[SHAPE]]
|
||||||
|
return %1 : !shape.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----
|
||||||
|
|
||||||
|
// CHECK-LABEL: min_same_arg
|
||||||
|
// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape)
|
||||||
|
func @min_same_arg(%a: !shape.shape) -> !shape.shape {
|
||||||
|
%1 = shape.min %a, %a : !shape.shape, !shape.shape -> !shape.shape
|
||||||
|
// CHECK: return %[[SHAPE]]
|
||||||
|
return %1 : !shape.shape
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue