diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp index 2d1a66c301f8..c705dc87bfa8 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp @@ -24,24 +24,23 @@ using namespace mlir; // Common utility functions //===----------------------------------------------------------------------===// -/// Returns true if the given `irVal` is a scalar or splat vector constant of -/// the given `boolVal`. -static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) { +/// Returns the boolean value under the hood if the given `boolAttr` is a scalar +/// or splat vector bool constant. +static Optional getScalarOrSplatBoolAttr(Attribute boolAttr) { if (!boolAttr) - return false; + return llvm::None; auto type = boolAttr.getType(); if (type.isInteger(1)) { auto attr = boolAttr.cast(); - return attr.getValue() == boolVal; + return attr.getValue(); } if (auto vecType = type.cast()) { if (vecType.getElementType().isInteger(1)) if (auto attr = boolAttr.dyn_cast()) - return attr.getSplatValue().template cast().getValue() == - boolVal; + return attr.getSplatValue(); } - return false; + return llvm::None; } // Extracts an element from the given `composite` by following the given @@ -214,13 +213,15 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef operands) { OpFoldResult spirv::LogicalAndOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "spv.LogicalAnd should take two operands"); - // x && true = x - if (isScalarOrSplatBoolAttr(operands.back(), true)) - return operand1(); + if (Optional rhs = getScalarOrSplatBoolAttr(operands.back())) { + // x && true = x + if (rhs.getValue()) + return operand1(); - // x && false = false - if (isScalarOrSplatBoolAttr(operands.back(), false)) - return operands.back(); + // x && false = false + if (!rhs.getValue()) + return operands.back(); + } return Attribute(); } @@ -243,13 +244,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns( OpFoldResult spirv::LogicalOrOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "spv.LogicalOr should take two operands"); - // x || true = true - if (isScalarOrSplatBoolAttr(operands.back(), true)) - return operands.back(); + if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) { + if (rhs.getValue()) + // x || true = true + return operands.back(); - // x || false = x - if (isScalarOrSplatBoolAttr(operands.back(), false)) - return operand1(); + // x || false = x + if (!rhs.getValue()) + return operand1(); + } return Attribute(); }