forked from OSchip/llvm-project
[MLIR][Shape] Fold `shape.mul`
Implement constant folding for `shape.mul`. Differential Revision: https://reviews.llvm.org/D84438
This commit is contained in:
parent
783a351785
commit
670ae4b6da
|
@ -326,6 +326,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
|||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
|
||||
|
|
|
@ -695,6 +695,18 @@ static LogicalResult verify(MulOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
|
||||
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
APInt folded = lhs.getValue() * rhs.getValue();
|
||||
Type indexTy = IndexType::get(getContext());
|
||||
return IntegerAttr::get(indexTy, folded);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeOfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -734,3 +734,43 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
|
|||
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Fold `mul` for constant sizes.
|
||||
// CHECK-LABEL: @fold_mul_size
|
||||
func @fold_mul_size() -> !shape.size {
|
||||
// CHECK: %[[RESULT:.*]] = shape.const_size 6
|
||||
// CHECK: return %[[RESULT]] : !shape.size
|
||||
%c2 = shape.const_size 2
|
||||
%c3 = shape.const_size 3
|
||||
%result = shape.mul %c2, %c3 : !shape.size, !shape.size -> !shape.size
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Fold `mul` for constant indices.
|
||||
// CHECK-LABEL: @fold_mul_index
|
||||
func @fold_mul_index() -> index {
|
||||
// CHECK: %[[RESULT:.*]] = constant 6 : index
|
||||
// CHECK: return %[[RESULT]] : index
|
||||
%c2 = constant 2 : index
|
||||
%c3 = constant 3 : index
|
||||
%result = shape.mul %c2, %c3 : index, index -> index
|
||||
return %result : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Fold `mul` for mixed constants.
|
||||
// CHECK-LABEL: @fold_mul_mixed
|
||||
func @fold_mul_mixed() -> !shape.size {
|
||||
// CHECK: %[[RESULT:.*]] = shape.const_size 6
|
||||
// CHECK: return %[[RESULT]] : !shape.size
|
||||
%c2 = shape.const_size 2
|
||||
%c3 = constant 3 : index
|
||||
%result = shape.mul %c2, %c3 : !shape.size, index -> !shape.size
|
||||
return %result : !shape.size
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue