[mlir][spirv] Refactoring to avoid calling the same function twice

This commit is contained in:
Lei Zhang 2020-02-26 15:35:46 -05:00
parent c753a306fd
commit 63779fb462
1 changed files with 23 additions and 20 deletions

View File

@ -24,24 +24,23 @@ using namespace mlir;
// Common utility functions
//===----------------------------------------------------------------------===//
/// Returns true if the given `irVal` is a scalar or splat vector constant of
/// the given `boolVal`.
static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
/// or splat vector bool constant.
static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
if (!boolAttr)
return false;
return llvm::None;
auto type = boolAttr.getType();
if (type.isInteger(1)) {
auto attr = boolAttr.cast<BoolAttr>();
return attr.getValue() == boolVal;
return attr.getValue();
}
if (auto vecType = type.cast<VectorType>()) {
if (vecType.getElementType().isInteger(1))
if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
return attr.getSplatValue().template cast<BoolAttr>().getValue() ==
boolVal;
return attr.getSplatValue<bool>();
}
return false;
return llvm::None;
}
// Extracts an element from the given `composite` by following the given
@ -214,13 +213,15 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
// x && true = x
if (isScalarOrSplatBoolAttr(operands.back(), true))
return operand1();
if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
// x && true = x
if (rhs.getValue())
return operand1();
// x && false = false
if (isScalarOrSplatBoolAttr(operands.back(), false))
return operands.back();
// x && false = false
if (!rhs.getValue())
return operands.back();
}
return Attribute();
}
@ -243,13 +244,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
// x || true = true
if (isScalarOrSplatBoolAttr(operands.back(), true))
return operands.back();
if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
if (rhs.getValue())
// x || true = true
return operands.back();
// x || false = x
if (isScalarOrSplatBoolAttr(operands.back(), false))
return operand1();
// x || false = x
if (!rhs.getValue())
return operand1();
}
return Attribute();
}