From e3cd80ea9f0ac0d04f537feb70d8f9a1c7875863 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 24 Jan 2022 15:18:04 -0800 Subject: [PATCH] [mlir:Function][NFC] Use BitVector instead of ArrayRef for indices when removing arguments/results We already convert to BitVector internally, and other APIs (namely Operation::eraseOperands) already use BitVector as well. Switching over provides a common format between API and also reduces the amount of format conversions necessary. Fixes #53325 Differential Revision: https://reviews.llvm.org/D118083 --- mlir/include/mlir/IR/BuiltinTypes.h | 1 + mlir/include/mlir/IR/BuiltinTypes.td | 4 +- mlir/include/mlir/IR/FunctionInterfaces.h | 11 +++-- mlir/include/mlir/IR/FunctionInterfaces.td | 47 ++++++++++++------- .../Transforms/BufferResultsToOutParams.cpp | 10 ++-- mlir/lib/IR/BuiltinTypes.cpp | 4 +- mlir/lib/IR/FunctionInterfaces.cpp | 45 ++++++------------ mlir/test/lib/IR/TestFunc.cpp | 32 ++++--------- 8 files changed, 70 insertions(+), 84 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index e087e6bf55d8..b03d9ea9f575 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -13,6 +13,7 @@ #include "SubElementInterfaces.h" namespace llvm { +class BitVector; struct fltSemantics; } // namespace llvm diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index b6f90f92190a..a78413548ba8 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -166,8 +166,8 @@ def Builtin_Function : Builtin_Type<"Function", [ TypeRange resultTypes); /// Returns a new function type without the specified arguments and results. - FunctionType getWithoutArgsAndResults(ArrayRef argIndices, - ArrayRef resultIndices); + FunctionType getWithoutArgsAndResults(const llvm::BitVector &argIndices, + const llvm::BitVector &resultIndices); }]; } diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h index b81abceef614..b6d6a9515ffe 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.h +++ b/mlir/include/mlir/IR/FunctionInterfaces.h @@ -16,6 +16,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallString.h" namespace mlir { @@ -82,12 +83,12 @@ void insertFunctionResults(Operation *op, ArrayRef resultIndices, unsigned originalNumResults, Type newType); /// Erase the specified arguments and update the function type attribute. -void eraseFunctionArguments(Operation *op, ArrayRef argIndices, - unsigned originalNumArgs, Type newType); +void eraseFunctionArguments(Operation *op, const llvm::BitVector &argIndices, + Type newType); /// Erase the specified results and update the function type attribute. -void eraseFunctionResults(Operation *op, ArrayRef resultIndices, - unsigned originalNumResults, Type newType); +void eraseFunctionResults(Operation *op, const llvm::BitVector &resultIndices, + Type newType); /// Set a FunctionOpInterface operation's type signature. void setFunctionType(Operation *op, Type newType); @@ -100,7 +101,7 @@ TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef indices, /// Filters out any elements referenced by `indices`. If any types are removed, /// `storage` is used to hold the new type list. Returns the new type list. -TypeRange filterTypesOut(TypeRange types, ArrayRef indices, +TypeRange filterTypesOut(TypeRange types, const llvm::BitVector &indices, SmallVectorImpl &storage); //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td index 20c7d7bbd51b..124f6594081f 100644 --- a/mlir/include/mlir/IR/FunctionInterfaces.td +++ b/mlir/include/mlir/IR/FunctionInterfaces.td @@ -280,27 +280,31 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { } /// Erase a single argument at `argIndex`. - void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); } + void eraseArgument(unsigned argIndex) { + llvm::BitVector argsToErase($_op.getNumArguments()); + argsToErase.set(argIndex); + eraseArguments(argsToErase); + } /// Erases the arguments listed in `argIndices`. - /// `argIndices` is allowed to have duplicates and can be in any order. - void eraseArguments(ArrayRef argIndices) { - unsigned originalNumArgs = $_op.getNumArguments(); - Type newType = $_op.getTypeWithoutArgsAndResults(argIndices, {}); - function_interface_impl::eraseFunctionArguments(this->getOperation(), argIndices, - originalNumArgs, newType); + void eraseArguments(const llvm::BitVector &argIndices) { + Type newType = $_op.getTypeWithoutArgs(argIndices); + function_interface_impl::eraseFunctionArguments( + this->getOperation(), argIndices, newType); } /// Erase a single result at `resultIndex`. - void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); } + void eraseResult(unsigned resultIndex) { + llvm::BitVector resultsToErase($_op.getNumResults()); + resultsToErase.set(resultIndex); + eraseResults(resultsToErase); + } /// Erases the results listed in `resultIndices`. - /// `resultIndices` is allowed to have duplicates and can be in any order. - void eraseResults(ArrayRef resultIndices) { - unsigned originalNumResults = $_op.getNumResults(); - Type newType = $_op.getTypeWithoutArgsAndResults({}, resultIndices); + void eraseResults(const llvm::BitVector &resultIndices) { + Type newType = $_op.getTypeWithoutResults(resultIndices); function_interface_impl::eraseFunctionResults( - this->getOperation(), resultIndices, originalNumResults, newType); + this->getOperation(), resultIndices, newType); } /// Return the type of this function with the specified arguments and @@ -320,10 +324,9 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { /// Return the type of this function without the specified arguments and /// results. This is used to update the function's signature in the - /// `eraseArguments` and `eraseResults` methods. The arrays of indices are - /// allowed to have duplicates and can be in any order. + /// `eraseArguments` and `eraseResults` methods. Type getTypeWithoutArgsAndResults( - ArrayRef argIndices, ArrayRef resultIndices) { + const llvm::BitVector &argIndices, const llvm::BitVector &resultIndices) { SmallVector argStorage, resultStorage; TypeRange newArgTypes = function_interface_impl::filterTypesOut( $_op.getArgumentTypes(), argIndices, argStorage); @@ -331,6 +334,18 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface"> { $_op.getResultTypes(), resultIndices, resultStorage); return $_op.cloneTypeWith(newArgTypes, newResultTypes); } + Type getTypeWithoutArgs(const llvm::BitVector &argIndices) { + SmallVector argStorage; + TypeRange newArgTypes = function_interface_impl::filterTypesOut( + $_op.getArgumentTypes(), argIndices, argStorage); + return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes()); + } + Type getTypeWithoutResults(const llvm::BitVector &resultIndices) { + SmallVector resultStorage; + TypeRange newResultTypes = function_interface_impl::filterTypesOut( + $_op.getResultTypes(), resultIndices, resultStorage); + return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes); + } //===------------------------------------------------------------------===// // Argument Attributes diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 08780db5f94d..585e873b4188 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -24,10 +24,10 @@ static void updateFuncOp(FuncOp func, // Collect information about the results will become appended arguments. SmallVector erasedResultTypes; - SmallVector erasedResultIndices; + llvm::BitVector erasedResultIndices(functionType.getNumResults()); for (const auto &resultType : llvm::enumerate(functionType.getResults())) { if (resultType.value().isa()) { - erasedResultIndices.push_back(resultType.index()); + erasedResultIndices.set(resultType.index()); erasedResultTypes.push_back(resultType.value()); } } @@ -40,9 +40,11 @@ static void updateFuncOp(FuncOp func, func.setType(newFunctionType); // Transfer the result attributes to arg attributes. - for (int i = 0, e = erasedResultTypes.size(); i < e; i++) + auto erasedIndicesIt = erasedResultIndices.set_bits_begin(); + for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) { func.setArgAttrs(functionType.getNumInputs() + i, - func.getResultAttrs(erasedResultIndices[i])); + func.getResultAttrs(*erasedIndicesIt)); + } // Erase the results. func.eraseResults(erasedResultIndices); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 46166f16f168..63a596280390 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -172,8 +172,8 @@ FunctionType FunctionType::getWithArgsAndResults( /// Returns a new function type without the specified arguments and results. FunctionType -FunctionType::getWithoutArgsAndResults(ArrayRef argIndices, - ArrayRef resultIndices) { +FunctionType::getWithoutArgsAndResults(const llvm::BitVector &argIndices, + const llvm::BitVector &resultIndices) { SmallVector argStorage, resultStorage; TypeRange newArgTypes = function_interface_impl::filterTypesOut( getInputs(), argIndices, argStorage); diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp index 07da5ce1716f..4f31c59f3f69 100644 --- a/mlir/lib/IR/FunctionInterfaces.cpp +++ b/mlir/lib/IR/FunctionInterfaces.cpp @@ -7,26 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/FunctionInterfaces.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/BitVector.h" using namespace mlir; -/// Helper to call a callback once on each index in the range -/// [0, `totalIndices`), *except* for the indices given in `indices`. -/// `indices` is allowed to have duplicates and can be in any order. -inline static void iterateIndicesExcept(unsigned totalIndices, - ArrayRef indices, - function_ref callback) { - llvm::BitVector skipIndices(totalIndices); - for (unsigned i : indices) - skipIndices.set(i); - - for (unsigned i = 0; i < totalIndices; ++i) - if (!skipIndices.test(i)) - callback(i); -} - //===----------------------------------------------------------------------===// // Tablegen Interface Definitions //===----------------------------------------------------------------------===// @@ -217,8 +200,7 @@ void mlir::function_interface_impl::insertFunctionResults( } void mlir::function_interface_impl::eraseFunctionArguments( - Operation *op, ArrayRef argIndices, unsigned originalNumArgs, - Type newType) { + Operation *op, const llvm::BitVector &argIndices, Type newType) { // There are 3 things that need to be updated: // - Function type. // - Arg attrs. @@ -229,9 +211,9 @@ void mlir::function_interface_impl::eraseFunctionArguments( if (auto argAttrs = op->getAttrOfType(getArgDictAttrName())) { SmallVector newArgAttrs; newArgAttrs.reserve(argAttrs.size()); - iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { - newArgAttrs.emplace_back(argAttrs[i].cast()); - }); + for (unsigned i = 0, e = argIndices.size(); i < e; ++i) + if (!argIndices[i]) + newArgAttrs.emplace_back(argAttrs[i].cast()); setAllArgAttrDicts(op, newArgAttrs); } @@ -241,8 +223,7 @@ void mlir::function_interface_impl::eraseFunctionArguments( } void mlir::function_interface_impl::eraseFunctionResults( - Operation *op, ArrayRef resultIndices, - unsigned originalNumResults, Type newType) { + Operation *op, const llvm::BitVector &resultIndices, Type newType) { // There are 2 things that need to be updated: // - Function type. // - Result attrs. @@ -251,9 +232,9 @@ void mlir::function_interface_impl::eraseFunctionResults( if (auto resAttrs = op->getAttrOfType(getResultDictAttrName())) { SmallVector newResultAttrs; newResultAttrs.reserve(resAttrs.size()); - iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { - newResultAttrs.emplace_back(resAttrs[i].cast()); - }); + for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) + if (!resultIndices[i]) + newResultAttrs.emplace_back(resAttrs[i].cast()); setAllResultAttrDicts(op, newResultAttrs); } @@ -282,12 +263,14 @@ TypeRange mlir::function_interface_impl::insertTypesInto( TypeRange mlir::function_interface_impl::filterTypesOut(TypeRange types, - ArrayRef indices, + const llvm::BitVector &indices, SmallVectorImpl &storage) { - if (indices.empty()) + if (indices.none()) return types; - iterateIndicesExcept(types.size(), indices, - [&](unsigned i) { storage.emplace_back(types[i]); }); + + for (unsigned i = 0, e = types.size(); i < e; ++i) + if (!indices[i]) + storage.emplace_back(types[i]); return storage; } diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp index dee9f8a5b2e5..0b4ce1d05992 100644 --- a/mlir/test/lib/IR/TestFunc.cpp +++ b/mlir/test/lib/IR/TestFunc.cpp @@ -87,18 +87,10 @@ struct TestFuncEraseArg auto module = getOperation(); for (FuncOp func : module.getOps()) { - SmallVector indicesToErase; - for (auto argIndex : llvm::seq(0, func.getNumArguments())) { - if (func.getArgAttr(argIndex, "test.erase_this_arg")) { - // Push back twice to test that duplicate arg indices are handled - // correctly. - indicesToErase.push_back(argIndex); - indicesToErase.push_back(argIndex); - } - } - // Reverse the order to test that unsorted index lists are handled - // correctly. - std::reverse(indicesToErase.begin(), indicesToErase.end()); + llvm::BitVector indicesToErase(func.getNumArguments()); + for (auto argIndex : llvm::seq(0, func.getNumArguments())) + if (func.getArgAttr(argIndex, "test.erase_this_arg")) + indicesToErase.set(argIndex); func.eraseArguments(indicesToErase); } } @@ -115,18 +107,10 @@ struct TestFuncEraseResult auto module = getOperation(); for (FuncOp func : module.getOps()) { - SmallVector indicesToErase; - for (auto resultIndex : llvm::seq(0, func.getNumResults())) { - if (func.getResultAttr(resultIndex, "test.erase_this_result")) { - // Push back twice to test that duplicate indices are handled - // correctly. - indicesToErase.push_back(resultIndex); - indicesToErase.push_back(resultIndex); - } - } - // Reverse the order to test that unsorted index lists are handled - // correctly. - std::reverse(indicesToErase.begin(), indicesToErase.end()); + llvm::BitVector indicesToErase(func.getNumResults()); + for (auto resultIndex : llvm::seq(0, func.getNumResults())) + if (func.getResultAttr(resultIndex, "test.erase_this_result")) + indicesToErase.set(resultIndex); func.eraseResults(indicesToErase); } }