[mlir][IR] Add an Operation::eraseOperands that supports batch erasure

This method allows for removing multiple disjoint operands at once, reducing the need to erase operands individually (which results in shifting the operand list).

Differential Revision: https://reviews.llvm.org/D98290
This commit is contained in:
River Riddle 2021-03-09 15:02:03 -08:00
parent 4a7aed4ee7
commit a776ecb6c2
4 changed files with 68 additions and 2 deletions

View File

@ -243,6 +243,12 @@ public:
getOperandStorage().eraseOperands(idx, length); getOperandStorage().eraseOperands(idx, length);
} }
/// Erases the operands that have their corresponding bit set in
/// `eraseIndices` and removes them from the operand list.
void eraseOperands(const llvm::BitVector &eraseIndices) {
getOperandStorage().eraseOperands(eraseIndices);
}
// Support operand iteration. // Support operand iteration.
using operand_range = OperandRange; using operand_range = OperandRange;
using operand_iterator = operand_range::iterator; using operand_iterator = operand_range::iterator;

View File

@ -28,6 +28,10 @@
#include "llvm/Support/TrailingObjects.h" #include "llvm/Support/TrailingObjects.h"
#include <memory> #include <memory>
namespace llvm {
class BitVector;
} // end namespace llvm
namespace mlir { namespace mlir {
class Dialect; class Dialect;
class DictionaryAttr; class DictionaryAttr;
@ -495,6 +499,10 @@ public:
/// Erase the operands held by the storage within the given range. /// Erase the operands held by the storage within the given range.
void eraseOperands(unsigned start, unsigned length); void eraseOperands(unsigned start, unsigned length);
/// Erase the operands held by the storage that have their corresponding bit
/// set in `eraseIndices`.
void eraseOperands(const llvm::BitVector &eraseIndices);
/// Get the operation operands held by the storage. /// Get the operation operands held by the storage.
MutableArrayRef<OpOperand> getOperands() { MutableArrayRef<OpOperand> getOperands() {
return getStorage().getOperands(); return getStorage().getOperands();

View File

@ -12,10 +12,10 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/IR/OperationSupport.h" #include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h" #include "llvm/ADT/BitVector.h"
using namespace mlir; using namespace mlir;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -300,6 +300,26 @@ void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
operands[storage.numOperands + i].~OpOperand(); operands[storage.numOperands + i].~OpOperand();
} }
void detail::OperandStorage::eraseOperands(
const llvm::BitVector &eraseIndices) {
TrailingOperandStorage &storage = getStorage();
MutableArrayRef<OpOperand> operands = storage.getOperands();
assert(eraseIndices.size() == operands.size());
// Check that at least one operand is erased.
int firstErasedIndice = eraseIndices.find_first();
if (firstErasedIndice == -1)
return;
// Shift all of the removed operands to the end, and destroy them.
storage.numOperands = firstErasedIndice;
for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
if (!eraseIndices.test(i))
operands[storage.numOperands++] = std::move(operands[i]);
for (OpOperand &operand : operands.drop_front(storage.numOperands))
operand.~OpOperand();
}
/// Resize the storage to the given size. Returns the array containing the new /// Resize the storage to the given size. Returns the array containing the new
/// operands. /// operands.
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner, MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,

View File

@ -9,6 +9,7 @@
#include "mlir/IR/OperationSupport.h" #include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/BitVector.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
using namespace mlir; using namespace mlir;
@ -150,6 +151,37 @@ TEST(OperandStorageTest, MutableRange) {
useOp->destroy(); useOp->destroy();
} }
TEST(OperandStorageTest, RangeErase) {
MLIRContext context;
Builder builder(&context);
Type type = builder.getNoneType();
Operation *useOp = createOp(&context, /*operands=*/llvm::None, {type, type});
Value operand1 = useOp->getResult(0);
Value operand2 = useOp->getResult(1);
// Create an operation with operands to erase.
Operation *user =
createOp(&context, {operand2, operand1, operand2, operand1});
llvm::BitVector eraseIndices(user->getNumOperands());
// Check erasing no operands.
user->eraseOperands(eraseIndices);
EXPECT_EQ(user->getNumOperands(), 4u);
// Check erasing disjoint operands.
eraseIndices.set(0);
eraseIndices.set(3);
user->eraseOperands(eraseIndices);
EXPECT_EQ(user->getNumOperands(), 2u);
EXPECT_EQ(user->getOperand(0), operand1);
EXPECT_EQ(user->getOperand(1), operand2);
// Destroy the operations.
user->destroy();
useOp->destroy();
}
TEST(OperationOrderTest, OrderIsAlwaysValid) { TEST(OperationOrderTest, OrderIsAlwaysValid) {
MLIRContext context; MLIRContext context;
Builder builder(&context); Builder builder(&context);