[MLIR][Shape] Fold `shape.mul`

Implement constant folding for `shape.mul`.

Differential Revision: https://reviews.llvm.org/D84438
This commit is contained in:
Frederik Gossen 2020-07-24 13:29:51 +00:00
parent 783a351785
commit 670ae4b6da
3 changed files with 53 additions and 0 deletions

View File

@ -326,6 +326,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
}];
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
}
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {

View File

@ -695,6 +695,18 @@ static LogicalResult verify(MulOp op) {
return success();
}
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
if (!lhs)
return nullptr;
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
if (!rhs)
return nullptr;
APInt folded = lhs.getValue() * rhs.getValue();
Type indexTy = IndexType::get(getContext());
return IntegerAttr::get(indexTy, folded);
}
//===----------------------------------------------------------------------===//
// ShapeOfOp
//===----------------------------------------------------------------------===//

View File

@ -734,3 +734,43 @@ func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
}
// -----
// Fold `mul` for constant sizes.
// CHECK-LABEL: @fold_mul_size
func @fold_mul_size() -> !shape.size {
// CHECK: %[[RESULT:.*]] = shape.const_size 6
// CHECK: return %[[RESULT]] : !shape.size
%c2 = shape.const_size 2
%c3 = shape.const_size 3
%result = shape.mul %c2, %c3 : !shape.size, !shape.size -> !shape.size
return %result : !shape.size
}
// -----
// Fold `mul` for constant indices.
// CHECK-LABEL: @fold_mul_index
func @fold_mul_index() -> index {
// CHECK: %[[RESULT:.*]] = constant 6 : index
// CHECK: return %[[RESULT]] : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%result = shape.mul %c2, %c3 : index, index -> index
return %result : index
}
// -----
// Fold `mul` for mixed constants.
// CHECK-LABEL: @fold_mul_mixed
func @fold_mul_mixed() -> !shape.size {
// CHECK: %[[RESULT:.*]] = shape.const_size 6
// CHECK: return %[[RESULT]] : !shape.size
%c2 = shape.const_size 2
%c3 = constant 3 : index
%result = shape.mul %c2, %c3 : !shape.size, index -> !shape.size
return %result : !shape.size
}