forked from OSchip/llvm-project
[mlir:Function][NFC] Use BitVector instead of ArrayRef for indices when removing arguments/results
We already convert to BitVector internally, and other APIs (namely Operation::eraseOperands) already use BitVector as well. Switching over provides a common format between API and also reduces the amount of format conversions necessary. Fixes #53325 Differential Revision: https://reviews.llvm.org/D118083
This commit is contained in:
parent
f4a368689f
commit
e3cd80ea9f
|
@ -13,6 +13,7 @@
|
|||
#include "SubElementInterfaces.h"
|
||||
|
||||
namespace llvm {
|
||||
class BitVector;
|
||||
struct fltSemantics;
|
||||
} // namespace llvm
|
||||
|
||||
|
|
|
@ -166,8 +166,8 @@ def Builtin_Function : Builtin_Type<"Function", [
|
|||
TypeRange resultTypes);
|
||||
|
||||
/// Returns a new function type without the specified arguments and results.
|
||||
FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
|
||||
ArrayRef<unsigned> resultIndices);
|
||||
FunctionType getWithoutArgsAndResults(const llvm::BitVector &argIndices,
|
||||
const llvm::BitVector &resultIndices);
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -82,12 +83,12 @@ void insertFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
|
|||
unsigned originalNumResults, Type newType);
|
||||
|
||||
/// Erase the specified arguments and update the function type attribute.
|
||||
void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
|
||||
unsigned originalNumArgs, Type newType);
|
||||
void eraseFunctionArguments(Operation *op, const llvm::BitVector &argIndices,
|
||||
Type newType);
|
||||
|
||||
/// Erase the specified results and update the function type attribute.
|
||||
void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
|
||||
unsigned originalNumResults, Type newType);
|
||||
void eraseFunctionResults(Operation *op, const llvm::BitVector &resultIndices,
|
||||
Type newType);
|
||||
|
||||
/// Set a FunctionOpInterface operation's type signature.
|
||||
void setFunctionType(Operation *op, Type newType);
|
||||
|
@ -100,7 +101,7 @@ TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef<unsigned> indices,
|
|||
|
||||
/// Filters out any elements referenced by `indices`. If any types are removed,
|
||||
/// `storage` is used to hold the new type list. Returns the new type list.
|
||||
TypeRange filterTypesOut(TypeRange types, ArrayRef<unsigned> indices,
|
||||
TypeRange filterTypesOut(TypeRange types, const llvm::BitVector &indices,
|
||||
SmallVectorImpl<Type> &storage);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -280,27 +280,31 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
|
|||
}
|
||||
|
||||
/// Erase a single argument at `argIndex`.
|
||||
void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
|
||||
void eraseArgument(unsigned argIndex) {
|
||||
llvm::BitVector argsToErase($_op.getNumArguments());
|
||||
argsToErase.set(argIndex);
|
||||
eraseArguments(argsToErase);
|
||||
}
|
||||
|
||||
/// Erases the arguments listed in `argIndices`.
|
||||
/// `argIndices` is allowed to have duplicates and can be in any order.
|
||||
void eraseArguments(ArrayRef<unsigned> argIndices) {
|
||||
unsigned originalNumArgs = $_op.getNumArguments();
|
||||
Type newType = $_op.getTypeWithoutArgsAndResults(argIndices, {});
|
||||
function_interface_impl::eraseFunctionArguments(this->getOperation(), argIndices,
|
||||
originalNumArgs, newType);
|
||||
void eraseArguments(const llvm::BitVector &argIndices) {
|
||||
Type newType = $_op.getTypeWithoutArgs(argIndices);
|
||||
function_interface_impl::eraseFunctionArguments(
|
||||
this->getOperation(), argIndices, newType);
|
||||
}
|
||||
|
||||
/// Erase a single result at `resultIndex`.
|
||||
void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
|
||||
void eraseResult(unsigned resultIndex) {
|
||||
llvm::BitVector resultsToErase($_op.getNumResults());
|
||||
resultsToErase.set(resultIndex);
|
||||
eraseResults(resultsToErase);
|
||||
}
|
||||
|
||||
/// Erases the results listed in `resultIndices`.
|
||||
/// `resultIndices` is allowed to have duplicates and can be in any order.
|
||||
void eraseResults(ArrayRef<unsigned> resultIndices) {
|
||||
unsigned originalNumResults = $_op.getNumResults();
|
||||
Type newType = $_op.getTypeWithoutArgsAndResults({}, resultIndices);
|
||||
void eraseResults(const llvm::BitVector &resultIndices) {
|
||||
Type newType = $_op.getTypeWithoutResults(resultIndices);
|
||||
function_interface_impl::eraseFunctionResults(
|
||||
this->getOperation(), resultIndices, originalNumResults, newType);
|
||||
this->getOperation(), resultIndices, newType);
|
||||
}
|
||||
|
||||
/// Return the type of this function with the specified arguments and
|
||||
|
@ -320,10 +324,9 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
|
|||
|
||||
/// Return the type of this function without the specified arguments and
|
||||
/// results. This is used to update the function's signature in the
|
||||
/// `eraseArguments` and `eraseResults` methods. The arrays of indices are
|
||||
/// allowed to have duplicates and can be in any order.
|
||||
/// `eraseArguments` and `eraseResults` methods.
|
||||
Type getTypeWithoutArgsAndResults(
|
||||
ArrayRef<unsigned> argIndices, ArrayRef<unsigned> resultIndices) {
|
||||
const llvm::BitVector &argIndices, const llvm::BitVector &resultIndices) {
|
||||
SmallVector<Type> argStorage, resultStorage;
|
||||
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
|
||||
$_op.getArgumentTypes(), argIndices, argStorage);
|
||||
|
@ -331,6 +334,18 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> {
|
|||
$_op.getResultTypes(), resultIndices, resultStorage);
|
||||
return $_op.cloneTypeWith(newArgTypes, newResultTypes);
|
||||
}
|
||||
Type getTypeWithoutArgs(const llvm::BitVector &argIndices) {
|
||||
SmallVector<Type> argStorage;
|
||||
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
|
||||
$_op.getArgumentTypes(), argIndices, argStorage);
|
||||
return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes());
|
||||
}
|
||||
Type getTypeWithoutResults(const llvm::BitVector &resultIndices) {
|
||||
SmallVector<Type> resultStorage;
|
||||
TypeRange newResultTypes = function_interface_impl::filterTypesOut(
|
||||
$_op.getResultTypes(), resultIndices, resultStorage);
|
||||
return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Argument Attributes
|
||||
|
|
|
@ -24,10 +24,10 @@ static void updateFuncOp(FuncOp func,
|
|||
|
||||
// Collect information about the results will become appended arguments.
|
||||
SmallVector<Type, 6> erasedResultTypes;
|
||||
SmallVector<unsigned, 6> erasedResultIndices;
|
||||
llvm::BitVector erasedResultIndices(functionType.getNumResults());
|
||||
for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
|
||||
if (resultType.value().isa<BaseMemRefType>()) {
|
||||
erasedResultIndices.push_back(resultType.index());
|
||||
erasedResultIndices.set(resultType.index());
|
||||
erasedResultTypes.push_back(resultType.value());
|
||||
}
|
||||
}
|
||||
|
@ -40,9 +40,11 @@ static void updateFuncOp(FuncOp func,
|
|||
func.setType(newFunctionType);
|
||||
|
||||
// Transfer the result attributes to arg attributes.
|
||||
for (int i = 0, e = erasedResultTypes.size(); i < e; i++)
|
||||
auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
|
||||
for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
|
||||
func.setArgAttrs(functionType.getNumInputs() + i,
|
||||
func.getResultAttrs(erasedResultIndices[i]));
|
||||
func.getResultAttrs(*erasedIndicesIt));
|
||||
}
|
||||
|
||||
// Erase the results.
|
||||
func.eraseResults(erasedResultIndices);
|
||||
|
|
|
@ -172,8 +172,8 @@ FunctionType FunctionType::getWithArgsAndResults(
|
|||
|
||||
/// Returns a new function type without the specified arguments and results.
|
||||
FunctionType
|
||||
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
|
||||
ArrayRef<unsigned> resultIndices) {
|
||||
FunctionType::getWithoutArgsAndResults(const llvm::BitVector &argIndices,
|
||||
const llvm::BitVector &resultIndices) {
|
||||
SmallVector<Type> argStorage, resultStorage;
|
||||
TypeRange newArgTypes = function_interface_impl::filterTypesOut(
|
||||
getInputs(), argIndices, argStorage);
|
||||
|
|
|
@ -7,26 +7,9 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/BitVector.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Helper to call a callback once on each index in the range
|
||||
/// [0, `totalIndices`), *except* for the indices given in `indices`.
|
||||
/// `indices` is allowed to have duplicates and can be in any order.
|
||||
inline static void iterateIndicesExcept(unsigned totalIndices,
|
||||
ArrayRef<unsigned> indices,
|
||||
function_ref<void(unsigned)> callback) {
|
||||
llvm::BitVector skipIndices(totalIndices);
|
||||
for (unsigned i : indices)
|
||||
skipIndices.set(i);
|
||||
|
||||
for (unsigned i = 0; i < totalIndices; ++i)
|
||||
if (!skipIndices.test(i))
|
||||
callback(i);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tablegen Interface Definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -217,8 +200,7 @@ void mlir::function_interface_impl::insertFunctionResults(
|
|||
}
|
||||
|
||||
void mlir::function_interface_impl::eraseFunctionArguments(
|
||||
Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
|
||||
Type newType) {
|
||||
Operation *op, const llvm::BitVector &argIndices, Type newType) {
|
||||
// There are 3 things that need to be updated:
|
||||
// - Function type.
|
||||
// - Arg attrs.
|
||||
|
@ -229,9 +211,9 @@ void mlir::function_interface_impl::eraseFunctionArguments(
|
|||
if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
|
||||
SmallVector<DictionaryAttr, 4> newArgAttrs;
|
||||
newArgAttrs.reserve(argAttrs.size());
|
||||
iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
|
||||
newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
|
||||
});
|
||||
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
|
||||
if (!argIndices[i])
|
||||
newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
|
||||
setAllArgAttrDicts(op, newArgAttrs);
|
||||
}
|
||||
|
||||
|
@ -241,8 +223,7 @@ void mlir::function_interface_impl::eraseFunctionArguments(
|
|||
}
|
||||
|
||||
void mlir::function_interface_impl::eraseFunctionResults(
|
||||
Operation *op, ArrayRef<unsigned> resultIndices,
|
||||
unsigned originalNumResults, Type newType) {
|
||||
Operation *op, const llvm::BitVector &resultIndices, Type newType) {
|
||||
// There are 2 things that need to be updated:
|
||||
// - Function type.
|
||||
// - Result attrs.
|
||||
|
@ -251,9 +232,9 @@ void mlir::function_interface_impl::eraseFunctionResults(
|
|||
if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
|
||||
SmallVector<DictionaryAttr, 4> newResultAttrs;
|
||||
newResultAttrs.reserve(resAttrs.size());
|
||||
iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
|
||||
newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
|
||||
});
|
||||
for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
|
||||
if (!resultIndices[i])
|
||||
newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
|
||||
setAllResultAttrDicts(op, newResultAttrs);
|
||||
}
|
||||
|
||||
|
@ -282,12 +263,14 @@ TypeRange mlir::function_interface_impl::insertTypesInto(
|
|||
|
||||
TypeRange
|
||||
mlir::function_interface_impl::filterTypesOut(TypeRange types,
|
||||
ArrayRef<unsigned> indices,
|
||||
const llvm::BitVector &indices,
|
||||
SmallVectorImpl<Type> &storage) {
|
||||
if (indices.empty())
|
||||
if (indices.none())
|
||||
return types;
|
||||
iterateIndicesExcept(types.size(), indices,
|
||||
[&](unsigned i) { storage.emplace_back(types[i]); });
|
||||
|
||||
for (unsigned i = 0, e = types.size(); i < e; ++i)
|
||||
if (!indices[i])
|
||||
storage.emplace_back(types[i]);
|
||||
return storage;
|
||||
}
|
||||
|
||||
|
|
|
@ -87,18 +87,10 @@ struct TestFuncEraseArg
|
|||
auto module = getOperation();
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
SmallVector<unsigned, 4> indicesToErase;
|
||||
for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
|
||||
if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
|
||||
// Push back twice to test that duplicate arg indices are handled
|
||||
// correctly.
|
||||
indicesToErase.push_back(argIndex);
|
||||
indicesToErase.push_back(argIndex);
|
||||
}
|
||||
}
|
||||
// Reverse the order to test that unsorted index lists are handled
|
||||
// correctly.
|
||||
std::reverse(indicesToErase.begin(), indicesToErase.end());
|
||||
llvm::BitVector indicesToErase(func.getNumArguments());
|
||||
for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
|
||||
if (func.getArgAttr(argIndex, "test.erase_this_arg"))
|
||||
indicesToErase.set(argIndex);
|
||||
func.eraseArguments(indicesToErase);
|
||||
}
|
||||
}
|
||||
|
@ -115,18 +107,10 @@ struct TestFuncEraseResult
|
|||
auto module = getOperation();
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
SmallVector<unsigned, 4> indicesToErase;
|
||||
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
|
||||
if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
|
||||
// Push back twice to test that duplicate indices are handled
|
||||
// correctly.
|
||||
indicesToErase.push_back(resultIndex);
|
||||
indicesToErase.push_back(resultIndex);
|
||||
}
|
||||
}
|
||||
// Reverse the order to test that unsorted index lists are handled
|
||||
// correctly.
|
||||
std::reverse(indicesToErase.begin(), indicesToErase.end());
|
||||
llvm::BitVector indicesToErase(func.getNumResults());
|
||||
for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
|
||||
if (func.getResultAttr(resultIndex, "test.erase_this_result"))
|
||||
indicesToErase.set(resultIndex);
|
||||
func.eraseResults(indicesToErase);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue