forked from OSchip/llvm-project
[mlir][IR] Add an Operation::eraseOperands that supports batch erasure
This method allows for removing multiple disjoint operands at once, reducing the need to erase operands individually (which results in shifting the operand list). Differential Revision: https://reviews.llvm.org/D98290
This commit is contained in:
parent
4a7aed4ee7
commit
a776ecb6c2
|
@ -243,6 +243,12 @@ public:
|
|||
getOperandStorage().eraseOperands(idx, length);
|
||||
}
|
||||
|
||||
/// Erases the operands that have their corresponding bit set in
|
||||
/// `eraseIndices` and removes them from the operand list.
|
||||
void eraseOperands(const llvm::BitVector &eraseIndices) {
|
||||
getOperandStorage().eraseOperands(eraseIndices);
|
||||
}
|
||||
|
||||
// Support operand iteration.
|
||||
using operand_range = OperandRange;
|
||||
using operand_iterator = operand_range::iterator;
|
||||
|
|
|
@ -28,6 +28,10 @@
|
|||
#include "llvm/Support/TrailingObjects.h"
|
||||
#include <memory>
|
||||
|
||||
namespace llvm {
|
||||
class BitVector;
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class Dialect;
|
||||
class DictionaryAttr;
|
||||
|
@ -495,6 +499,10 @@ public:
|
|||
/// Erase the operands held by the storage within the given range.
|
||||
void eraseOperands(unsigned start, unsigned length);
|
||||
|
||||
/// Erase the operands held by the storage that have their corresponding bit
|
||||
/// set in `eraseIndices`.
|
||||
void eraseOperands(const llvm::BitVector &eraseIndices);
|
||||
|
||||
/// Get the operation operands held by the storage.
|
||||
MutableArrayRef<OpOperand> getOperands() {
|
||||
return getStorage().getOperands();
|
||||
|
|
|
@ -12,10 +12,10 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -300,6 +300,26 @@ void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) {
|
|||
operands[storage.numOperands + i].~OpOperand();
|
||||
}
|
||||
|
||||
void detail::OperandStorage::eraseOperands(
|
||||
const llvm::BitVector &eraseIndices) {
|
||||
TrailingOperandStorage &storage = getStorage();
|
||||
MutableArrayRef<OpOperand> operands = storage.getOperands();
|
||||
assert(eraseIndices.size() == operands.size());
|
||||
|
||||
// Check that at least one operand is erased.
|
||||
int firstErasedIndice = eraseIndices.find_first();
|
||||
if (firstErasedIndice == -1)
|
||||
return;
|
||||
|
||||
// Shift all of the removed operands to the end, and destroy them.
|
||||
storage.numOperands = firstErasedIndice;
|
||||
for (unsigned i = firstErasedIndice + 1, e = operands.size(); i < e; ++i)
|
||||
if (!eraseIndices.test(i))
|
||||
operands[storage.numOperands++] = std::move(operands[i]);
|
||||
for (OpOperand &operand : operands.drop_front(storage.numOperands))
|
||||
operand.~OpOperand();
|
||||
}
|
||||
|
||||
/// Resize the storage to the given size. Returns the array containing the new
|
||||
/// operands.
|
||||
MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -150,6 +151,37 @@ TEST(OperandStorageTest, MutableRange) {
|
|||
useOp->destroy();
|
||||
}
|
||||
|
||||
TEST(OperandStorageTest, RangeErase) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
||||
Type type = builder.getNoneType();
|
||||
Operation *useOp = createOp(&context, /*operands=*/llvm::None, {type, type});
|
||||
Value operand1 = useOp->getResult(0);
|
||||
Value operand2 = useOp->getResult(1);
|
||||
|
||||
// Create an operation with operands to erase.
|
||||
Operation *user =
|
||||
createOp(&context, {operand2, operand1, operand2, operand1});
|
||||
llvm::BitVector eraseIndices(user->getNumOperands());
|
||||
|
||||
// Check erasing no operands.
|
||||
user->eraseOperands(eraseIndices);
|
||||
EXPECT_EQ(user->getNumOperands(), 4u);
|
||||
|
||||
// Check erasing disjoint operands.
|
||||
eraseIndices.set(0);
|
||||
eraseIndices.set(3);
|
||||
user->eraseOperands(eraseIndices);
|
||||
EXPECT_EQ(user->getNumOperands(), 2u);
|
||||
EXPECT_EQ(user->getOperand(0), operand1);
|
||||
EXPECT_EQ(user->getOperand(1), operand2);
|
||||
|
||||
// Destroy the operations.
|
||||
user->destroy();
|
||||
useOp->destroy();
|
||||
}
|
||||
|
||||
TEST(OperationOrderTest, OrderIsAlwaysValid) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
|
|
Loading…
Reference in New Issue