[mlir][Arithmetic] Add common constant folder function for type cast ops.

This revision replaces current type cast constant folder with a new common type cast constant folder function template.
It will cover all former folder and support fold the constant splat and vector.

Differential Revision: https://reviews.llvm.org/D123489
This commit is contained in:
jacquesguan 2022-04-11 09:24:43 +00:00
parent 47a9528fb4
commit 605fc89a61
3 changed files with 283 additions and 69 deletions

View File

@ -108,6 +108,56 @@ Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
return {}; return {};
} }
template <
class AttrElementT, class TargetAttrElementT,
class ElementValueT = typename AttrElementT::ValueType,
class TargetElementValueT = typename TargetAttrElementT::ValueType,
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
const CalculationT &calculate) {
assert(operands.size() == 1 && "Cast op takes one operand");
if (!operands[0])
return {};
if (operands[0].isa<AttrElementT>()) {
auto op = operands[0].cast<AttrElementT>();
bool castStatus = true;
auto res = calculate(op.getValue(), castStatus);
if (!castStatus)
return {};
return TargetAttrElementT::get(resType, res);
}
if (operands[0].isa<SplatElementsAttr>()) {
// The operand is a splat so we can avoid expanding the values out and
// just fold based on the splat value.
auto op = operands[0].cast<SplatElementsAttr>();
bool castStatus = true;
auto elementResult =
calculate(op.getSplatValue<ElementValueT>(), castStatus);
if (!castStatus)
return {};
return DenseElementsAttr::get(resType, elementResult);
}
if (operands[0].isa<ElementsAttr>()) {
// Operand is ElementsAttr-derived; perform an element-wise fold by
// expanding the value.
auto op = operands[0].cast<ElementsAttr>();
bool castStatus = true;
auto opIt = op.value_begin<ElementValueT>();
SmallVector<TargetElementValueT> elementResults;
elementResults.reserve(op.getNumElements());
for (size_t i = 0, e = op.getNumElements(); i < e; ++i, ++opIt) {
auto elt = calculate(*opIt, castStatus);
if (!castStatus)
return {};
elementResults.push_back(elt);
}
return DenseElementsAttr::get(resType, elementResults);
}
return {};
}
} // namespace mlir } // namespace mlir
#endif // MLIR_DIALECT_COMMONFOLDERS_H #endif // MLIR_DIALECT_COMMONFOLDERS_H

View File

@ -875,16 +875,20 @@ static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(
getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) { if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
getInMutable().assign(lhs.getIn()); getInMutable().assign(lhs.getIn());
return getResult(); return getResult();
} }
Type resType = getType();
return {}; unsigned bitWidth;
if (auto shapedType = resType.dyn_cast<ShapedType>())
bitWidth = shapedType.getElementTypeBitWidth();
else
bitWidth = resType.getIntOrFloatBitWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
return a.zext(bitWidth);
});
} }
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@ -900,16 +904,20 @@ LogicalResult arith::ExtUIOp::verify() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(
getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) { if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
getInMutable().assign(lhs.getIn()); getInMutable().assign(lhs.getIn());
return getResult(); return getResult();
} }
Type resType = getType();
return {}; unsigned bitWidth;
if (auto shapedType = resType.dyn_cast<ShapedType>())
bitWidth = shapedType.getElementTypeBitWidth();
else
bitWidth = resType.getIntOrFloatBitWidth();
return constFoldCastOp<IntegerAttr, IntegerAttr>(
operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
return a.sext(bitWidth);
});
} }
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@ -954,15 +962,17 @@ OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
return getResult(); return getResult();
} }
if (!operands[0]) Type resType = getType();
return {}; unsigned bitWidth;
if (auto shapedType = resType.dyn_cast<ShapedType>())
bitWidth = shapedType.getElementTypeBitWidth();
else
bitWidth = resType.getIntOrFloatBitWidth();
if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) { return constFoldCastOp<IntegerAttr, IntegerAttr>(
return IntegerAttr::get( operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth())); return a.trunc(bitWidth);
} });
return {};
} }
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
@ -1048,15 +1058,21 @@ bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
} }
OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { Type resType = getType();
const APInt &api = lhs.getValue(); Type resEleType;
FloatType floatTy = getType().cast<FloatType>(); if (auto shapedType = resType.dyn_cast<ShapedType>())
APFloat apf(floatTy.getFloatSemantics(), resEleType = shapedType.getElementType();
APInt::getZero(floatTy.getWidth())); else
apf.convertFromAPInt(api, /*IsSigned=*/false, APFloat::rmNearestTiesToEven); resEleType = resType;
return FloatAttr::get(floatTy, apf); return constFoldCastOp<IntegerAttr, FloatAttr>(
} operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
return {}; FloatType floatTy = resEleType.cast<FloatType>();
APFloat apf(floatTy.getFloatSemantics(),
APInt::getZero(floatTy.getWidth()));
apf.convertFromAPInt(a, /*IsSigned=*/false,
APFloat::rmNearestTiesToEven);
return apf;
});
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1068,15 +1084,21 @@ bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
} }
OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) { Type resType = getType();
const APInt &api = lhs.getValue(); Type resEleType;
FloatType floatTy = getType().cast<FloatType>(); if (auto shapedType = resType.dyn_cast<ShapedType>())
APFloat apf(floatTy.getFloatSemantics(), resEleType = shapedType.getElementType();
APInt::getZero(floatTy.getWidth())); else
apf.convertFromAPInt(api, /*IsSigned=*/true, APFloat::rmNearestTiesToEven); resEleType = resType;
return FloatAttr::get(floatTy, apf); return constFoldCastOp<IntegerAttr, FloatAttr>(
} operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
return {}; FloatType floatTy = resEleType.cast<FloatType>();
APFloat apf(floatTy.getFloatSemantics(),
APInt::getZero(floatTy.getWidth()));
apf.convertFromAPInt(a, /*IsSigned=*/true,
APFloat::rmNearestTiesToEven);
return apf;
});
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// FPToUIOp // FPToUIOp
@ -1087,21 +1109,21 @@ bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
} }
OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { Type resType = getType();
const APFloat &apf = lhs.getValue(); Type resEleType;
IntegerType intTy = getType().cast<IntegerType>(); if (auto shapedType = resType.dyn_cast<ShapedType>())
bool ignored; resEleType = shapedType.getElementType();
APSInt api(intTy.getWidth(), /*isUnsigned=*/true); else
if (APFloat::opInvalidOp == resEleType = resType;
apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { return constFoldCastOp<FloatAttr, IntegerAttr>(
// Undefined behavior invoked - the destination type can't represent operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
// the input constant. IntegerType intTy = resEleType.cast<IntegerType>();
return {}; bool ignored;
} APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
return IntegerAttr::get(getType(), api); castStatus = APFloat::opInvalidOp !=
} a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
return api;
return {}; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1113,21 +1135,21 @@ bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
} }
OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) { OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) { Type resType = getType();
const APFloat &apf = lhs.getValue(); Type resEleType;
IntegerType intTy = getType().cast<IntegerType>(); if (auto shapedType = resType.dyn_cast<ShapedType>())
bool ignored; resEleType = shapedType.getElementType();
APSInt api(intTy.getWidth(), /*isUnsigned=*/false); else
if (APFloat::opInvalidOp == resEleType = resType;
apf.convertToInteger(api, APFloat::rmTowardZero, &ignored)) { return constFoldCastOp<FloatAttr, IntegerAttr>(
// Undefined behavior invoked - the destination type can't represent operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
// the input constant. IntegerType intTy = resEleType.cast<IntegerType>();
return {}; bool ignored;
} APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
return IntegerAttr::get(getType(), api); castStatus = APFloat::opInvalidOp !=
} a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
return api;
return {}; });
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -282,6 +282,53 @@ func @signExtendConstant() -> i16 {
return %ext : i16 return %ext : i16
} }
// CHECK-LABEL: @signExtendConstantSplat
// CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi16>
// CHECK: return %[[cres]]
func @signExtendConstantSplat() -> vector<4xi16> {
%c-2 = arith.constant -2 : i8
%splat = vector.splat %c-2 : vector<4xi8>
%ext = arith.extsi %splat : vector<4xi8> to vector<4xi16>
return %ext : vector<4xi16>
}
// CHECK-LABEL: @signExtendConstantVector
// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16>
// CHECK: return %[[cres]]
func @signExtendConstantVector() -> vector<4xi16> {
%vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8>
%ext = arith.extsi %vector : vector<4xi8> to vector<4xi16>
return %ext : vector<4xi16>
}
// CHECK-LABEL: @unsignedExtendConstant
// CHECK: %[[cres:.+]] = arith.constant 2 : i16
// CHECK: return %[[cres]]
func @unsignedExtendConstant() -> i16 {
%c2 = arith.constant 2 : i8
%ext = arith.extui %c2 : i8 to i16
return %ext : i16
}
// CHECK-LABEL: @unsignedExtendConstantSplat
// CHECK: %[[cres:.+]] = arith.constant dense<2> : vector<4xi16>
// CHECK: return %[[cres]]
func @unsignedExtendConstantSplat() -> vector<4xi16> {
%c2 = arith.constant 2 : i8
%splat = vector.splat %c2 : vector<4xi8>
%ext = arith.extui %splat : vector<4xi8> to vector<4xi16>
return %ext : vector<4xi16>
}
// CHECK-LABEL: @unsignedExtendConstantVector
// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16>
// CHECK: return %[[cres]]
func @unsignedExtendConstantVector() -> vector<4xi16> {
%vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8>
%ext = arith.extui %vector : vector<4xi8> to vector<4xi16>
return %ext : vector<4xi16>
}
// CHECK-LABEL: @truncConstant // CHECK-LABEL: @truncConstant
// CHECK: %[[cres:.+]] = arith.constant -2 : i16 // CHECK: %[[cres:.+]] = arith.constant -2 : i16
// CHECK: return %[[cres]] // CHECK: return %[[cres]]
@ -291,6 +338,25 @@ func @truncConstant(%arg0: i8) -> i16 {
return %tr : i16 return %tr : i16
} }
// CHECK-LABEL: @truncConstantSplat
// CHECK: %[[cres:.+]] = arith.constant dense<-2> : vector<4xi8>
// CHECK: return %[[cres]]
func @truncConstantSplat() -> vector<4xi8> {
%c-2 = arith.constant -2 : i16
%splat = vector.splat %c-2 : vector<4xi16>
%trunc = arith.trunci %splat : vector<4xi16> to vector<4xi8>
return %trunc : vector<4xi8>
}
// CHECK-LABEL: @truncConstantVector
// CHECK: %[[cres:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi8>
// CHECK: return %[[cres]]
func @truncConstantVector() -> vector<4xi8> {
%vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi16>
%trunc = arith.trunci %vector : vector<4xi16> to vector<4xi8>
return %trunc : vector<4xi8>
}
// CHECK-LABEL: @truncTrunc // CHECK-LABEL: @truncTrunc
// CHECK: %[[cres:.+]] = arith.trunci %arg0 : i64 to i8 // CHECK: %[[cres:.+]] = arith.trunci %arg0 : i64 to i8
// CHECK: return %[[cres]] // CHECK: return %[[cres]]
@ -921,6 +987,25 @@ func @constant_FPtoUI() -> i32 {
return %res : i32 return %res : i32
} }
// CHECK-LABEL: @constant_FPtoUI_splat(
func @constant_FPtoUI_splat() -> vector<4xi32> {
// CHECK: %[[C0:.+]] = arith.constant dense<2> : vector<4xi32>
// CHECK: return %[[C0]]
%c0 = arith.constant 2.0 : f32
%splat = vector.splat %c0 : vector<4xf32>
%res = arith.fptoui %splat : vector<4xf32> to vector<4xi32>
return %res : vector<4xi32>
}
// CHECK-LABEL: @constant_FPtoUI_vector(
func @constant_FPtoUI_vector() -> vector<4xi32> {
// CHECK: %[[C0:.+]] = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
// CHECK: return %[[C0]]
%vector = arith.constant dense<[1.0, 3.0, 5.0, 7.0]> : vector<4xf32>
%res = arith.fptoui %vector : vector<4xf32> to vector<4xi32>
return %res : vector<4xi32>
}
// ----- // -----
// CHECK-LABEL: @invalid_constant_FPtoUI( // CHECK-LABEL: @invalid_constant_FPtoUI(
func @invalid_constant_FPtoUI() -> i32 { func @invalid_constant_FPtoUI() -> i32 {
@ -942,6 +1027,25 @@ func @constant_FPtoSI() -> i32 {
return %res : i32 return %res : i32
} }
// CHECK-LABEL: @constant_FPtoSI_splat(
func @constant_FPtoSI_splat() -> vector<4xi32> {
// CHECK: %[[C0:.+]] = arith.constant dense<-2> : vector<4xi32>
// CHECK: return %[[C0]]
%c0 = arith.constant -2.0 : f32
%splat = vector.splat %c0 : vector<4xf32>
%res = arith.fptosi %splat : vector<4xf32> to vector<4xi32>
return %res : vector<4xi32>
}
// CHECK-LABEL: @constant_FPtoSI_vector(
func @constant_FPtoSI_vector() -> vector<4xi32> {
// CHECK: %[[C0:.+]] = arith.constant dense<[-1, -3, -5, -7]> : vector<4xi32>
// CHECK: return %[[C0]]
%vector = arith.constant dense<[-1.0, -3.0, -5.0, -7.0]> : vector<4xf32>
%res = arith.fptosi %vector : vector<4xf32> to vector<4xi32>
return %res : vector<4xi32>
}
// ----- // -----
// CHECK-LABEL: @invalid_constant_FPtoSI( // CHECK-LABEL: @invalid_constant_FPtoSI(
func @invalid_constant_FPtoSI() -> i8 { func @invalid_constant_FPtoSI() -> i8 {
@ -962,16 +1066,54 @@ func @constant_SItoFP() -> f32 {
return %res : f32 return %res : f32
} }
// CHECK-LABEL: @constant_SItoFP_splat(
func @constant_SItoFP_splat() -> vector<4xf32> {
// CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK: return %[[C0]]
%c0 = arith.constant 2 : i32
%splat = vector.splat %c0 : vector<4xi32>
%res = arith.sitofp %splat : vector<4xi32> to vector<4xf32>
return %res : vector<4xf32>
}
// CHECK-LABEL: @constant_SItoFP_vector(
func @constant_SItoFP_vector() -> vector<4xf32> {
// CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, 3.000000e+00, 5.000000e+00, 7.000000e+00]> : vector<4xf32>
// CHECK: return %[[C0]]
%vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
%res = arith.sitofp %vector : vector<4xi32> to vector<4xf32>
return %res : vector<4xf32>
}
// ----- // -----
// CHECK-LABEL: @constant_UItoFP( // CHECK-LABEL: @constant_UItoFP(
func @constant_UItoFP() -> f32 { func @constant_UItoFP() -> f32 {
// CHECK: %[[C0:.+]] = arith.constant 2.000000e+00 : f32 // CHECK: %[[C0:.+]] = arith.constant 2.000000e+00 : f32
// CHECK: return %[[C0]] // CHECK: return %[[C0]]
%c0 = arith.constant 2 : i32 %c0 = arith.constant 2 : i32
%res = arith.sitofp %c0 : i32 to f32 %res = arith.uitofp %c0 : i32 to f32
return %res : f32 return %res : f32
} }
// CHECK-LABEL: @constant_UItoFP_splat(
func @constant_UItoFP_splat() -> vector<4xf32> {
// CHECK: %[[C0:.+]] = arith.constant dense<2.000000e+00> : vector<4xf32>
// CHECK: return %[[C0]]
%c0 = arith.constant 2 : i32
%splat = vector.splat %c0 : vector<4xi32>
%res = arith.uitofp %splat : vector<4xi32> to vector<4xf32>
return %res : vector<4xf32>
}
// CHECK-LABEL: @constant_UItoFP_vector(
func @constant_UItoFP_vector() -> vector<4xf32> {
// CHECK: %[[C0:.+]] = arith.constant dense<[1.000000e+00, 3.000000e+00, 5.000000e+00, 7.000000e+00]> : vector<4xf32>
// CHECK: return %[[C0]]
%vector = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
%res = arith.uitofp %vector : vector<4xi32> to vector<4xf32>
return %res : vector<4xf32>
}
// ----- // -----
// Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll // Tests rewritten from https://github.com/llvm/llvm-project/blob/main/llvm/test/Transforms/InstCombine/2008-11-08-FCmp.ll