forked from OSchip/llvm-project
[mlir][spirv] Refactoring to avoid calling the same function twice
This commit is contained in:
parent
c753a306fd
commit
63779fb462
|
@ -24,24 +24,23 @@ using namespace mlir;
|
|||
// Common utility functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns true if the given `irVal` is a scalar or splat vector constant of
|
||||
/// the given `boolVal`.
|
||||
static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
|
||||
/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
|
||||
/// or splat vector bool constant.
|
||||
static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
|
||||
if (!boolAttr)
|
||||
return false;
|
||||
return llvm::None;
|
||||
|
||||
auto type = boolAttr.getType();
|
||||
if (type.isInteger(1)) {
|
||||
auto attr = boolAttr.cast<BoolAttr>();
|
||||
return attr.getValue() == boolVal;
|
||||
return attr.getValue();
|
||||
}
|
||||
if (auto vecType = type.cast<VectorType>()) {
|
||||
if (vecType.getElementType().isInteger(1))
|
||||
if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
|
||||
return attr.getSplatValue().template cast<BoolAttr>().getValue() ==
|
||||
boolVal;
|
||||
return attr.getSplatValue<bool>();
|
||||
}
|
||||
return false;
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
// Extracts an element from the given `composite` by following the given
|
||||
|
@ -214,13 +213,15 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
|||
OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
|
||||
|
||||
// x && true = x
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), true))
|
||||
return operand1();
|
||||
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
|
||||
// x && true = x
|
||||
if (rhs.getValue())
|
||||
return operand1();
|
||||
|
||||
// x && false = false
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), false))
|
||||
return operands.back();
|
||||
// x && false = false
|
||||
if (!rhs.getValue())
|
||||
return operands.back();
|
||||
}
|
||||
|
||||
return Attribute();
|
||||
}
|
||||
|
@ -243,13 +244,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
|
|||
OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
|
||||
|
||||
// x || true = true
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), true))
|
||||
return operands.back();
|
||||
if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
|
||||
if (rhs.getValue())
|
||||
// x || true = true
|
||||
return operands.back();
|
||||
|
||||
// x || false = x
|
||||
if (isScalarOrSplatBoolAttr(operands.back(), false))
|
||||
return operand1();
|
||||
// x || false = x
|
||||
if (!rhs.getValue())
|
||||
return operand1();
|
||||
}
|
||||
|
||||
return Attribute();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue