From 1b88bbf5eb80b38a4dee129df969d5632993fdd1 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 2 Sep 2020 09:24:36 -0400 Subject: [PATCH] Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks" This reverts commit 94f5d248772ba0f1f9c8b0746fe75a5d246c5540 because of failing the following tests: MLIR :: Dialect/Linalg/tensors-to-buffers.mlir MLIR :: Transforms/buffer-placement-preparation-allowed-memref-results.mlir MLIR :: Transforms/buffer-placement-preparation.mlir --- .../include/mlir/Transforms/BufferPlacement.h | 344 +++++++++--------- .../Linalg/Transforms/TensorsToBuffers.cpp | 11 +- mlir/lib/Transforms/BufferPlacement.cpp | 220 +---------- ...nt-preparation-allowed-memref-results.mlir | 66 ---- .../buffer-placement-preparation.mlir | 85 ----- mlir/test/lib/Dialect/Test/TestOps.td | 29 +- .../lib/Transforms/TestBufferPlacement.cpp | 48 +-- 7 files changed, 191 insertions(+), 612 deletions(-) diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h index 8fc254e6be1e..f8559a9dd939 100644 --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -52,111 +52,6 @@ private: Operation *operation; }; -/// A helper type converter class for using inside Buffer Assignment operation -/// conversion patterns. The default constructor keeps all the types intact -/// except for the ranked-tensor types which is converted to memref types. -class BufferAssignmentTypeConverter : public TypeConverter { -public: - /// This enum is for showing how buffer placement operation converters should - /// conduct with certain result type after type conversion. This value can be - /// set/get for each specific type using setResultConversionKind or - /// getResultConversionKind. - enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult }; - - BufferAssignmentTypeConverter(); - - /// This method tries to decompose a value of a certain type using provided - /// decompose callback functions. If it is unable to do so, the original value - /// is returned. - void tryDecomposeValue(OpBuilder &, Location, Type, Value, - SmallVectorImpl &); - - /// This method tries to decompose a type using provided decompose callback - /// functions. If it is unable to do so, the original type is returned. - void tryDecomposeType(Type, SmallVectorImpl &); - - /// This method registers a callback function that will be called to decompose - /// a value of a certain type into several values. - template ::template arg_t<2>> - void addDecomposeValueConversion(FnT &&callback) { - decomposeValueConversions.emplace_back( - wrapDecomposeValueConversionCallback(std::forward(callback))); - } - - /// This method registers a callback function that will be called to decompose - /// a type into several types. - template ::template arg_t<0>> - void addDecomposeTypeConversion(FnT &&callback) { - auto wrapper = - wrapDecomposeTypeConversionCallback(std::forward(callback)); - decomposeTypeConversions.emplace_back(wrapper); - addConversion(std::forward(callback)); - } - - /// This method returns ResultConversionKind for the mapping from `origin` - /// type to `input` type. - ResultConversionKind getResultConversionKind(Type origin, Type input); - - /// This method registers ResultConversionKind for the mapping from type 'T' - /// to type 'U'. - template - void setResultConversionKind(ResultConversionKind kind) { - assert((kind != AppendToArgumentsList || - llvm::is_one_of::value) && - "Only the memref typed values can be set to be appended to the " - "function argument list at the moment"); - resultTypeConversions.emplace_back( - [&](Type origin, Type input) -> Optional { - if (origin.template isa() && input.template isa()) - return kind; - return llvm::None; - }); - } - -private: - using DecomposeValueConversionCallFn = std::function( - OpBuilder &, Location, Type, Value, SmallVectorImpl &)>; - - using DecomposeTypeConversionCallFn = - std::function(Type, SmallVectorImpl &)>; - - using ResultConversionKindFn = - std::function(Type, Type)>; - - /// Generate a wrapper for the given decompose value conversion callback. - template - DecomposeValueConversionCallFn - wrapDecomposeValueConversionCallback(FnT &&callback) { - return [callback = std::forward(callback)]( - OpBuilder &builder, Location loc, Type type, Value value, - SmallVectorImpl &newValues) -> Optional { - if (T derivedType = type.dyn_cast()) - return callback(builder, loc, derivedType, value, newValues); - return llvm::None; - }; - } - - /// Generate a wrapper for the given decompose type conversion callback. - template - DecomposeTypeConversionCallFn - wrapDecomposeTypeConversionCallback(FnT &&callback) { - return [callback = std::forward(callback)]( - Type type, - SmallVectorImpl &results) -> Optional { - T derivedType = type.dyn_cast(); - if (!derivedType) - return llvm::None; - return callback(derivedType, results); - }; - } - - SmallVector resultTypeConversions; - SmallVector decomposeValueConversions; - SmallVector decomposeTypeConversions; -}; - /// Helper conversion pattern that encapsulates a BufferAssignmentPlacer /// instance. Sample usage: /// class CustomConversionPattern : public @@ -173,22 +68,43 @@ class BufferAssignmentOpConversionPattern public: explicit BufferAssignmentOpConversionPattern( MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr, - BufferAssignmentTypeConverter *converter = nullptr, - PatternBenefit benefit = 1) + TypeConverter *converter = nullptr, PatternBenefit benefit = 1) : OpConversionPattern(context, benefit), - bufferAssignment(bufferAssignment), converter(converter) { - assert(converter && "The type converter has not been defined"); - } + bufferAssignment(bufferAssignment), converter(converter) {} protected: BufferAssignmentPlacer *bufferAssignment; - BufferAssignmentTypeConverter *converter; + TypeConverter *converter; }; -/// Converts the signature of the function using BufferAssignmentTypeConverter. -/// Each result type of the function is kept as a function result or appended to -/// the function arguments list based on ResultConversionKind for the converted -/// result type. +/// A helper type converter class for using inside Buffer Assignment operation +/// conversion patterns. The default constructor keeps all the types intact +/// except for the ranked-tensor types which is converted to memref types. +class BufferAssignmentTypeConverter : public TypeConverter { +public: + BufferAssignmentTypeConverter(); + + /// A helper function to check if `type` has been converted from non-memref + /// type to memref. + static bool isConvertedMemref(Type type, Type before); +}; + +namespace detail { + +/// Converts the signature of the function based on whether the function is +/// allowed to return memref typed results or not using +/// `allowMemrefFunctionResults` parameter. If this option is false, then it +/// adds an extra function argument as an output buffer for each function result +/// which is going to be a memref type only after type conversion. The +/// other function result types remain unchanged. If +/// `allowMemrefFunctionResults` is true, the types are converted in place. +/// Any changes in function signature need to be applied +/// to return and caller operations. `BufferAssignmentReturnOpConverter` and +/// `BufferAssignmentCallOpConverter` are two helper function that match the +/// return and caller operation with the new function signature. Furthermore, +/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting +/// tensor typed values to memref typed ones. +template class BufferAssignmentFuncOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -196,16 +112,58 @@ public: FuncOp>::BufferAssignmentOpConversionPattern; /// Performs the actual signature rewriting step. - LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef, - ConversionPatternRewriter &) const; + LogicalResult + matchAndRewrite(mlir::FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + if (!converter) + return funcOp.emitError("The type converter has not been defined for " + "BufferAssignmentFuncOpConverter"); + auto funcType = funcOp.getType(); + + // Convert function arguments using the provided TypeConverter. + TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); + for (auto argType : llvm::enumerate(funcType.getInputs())) + conversion.addInputs(argType.index(), + converter->convertType(argType.value())); + + // If allowMemrefFunctionResults is false and a function result type is not + // a memref but it would be a memref after type conversion, a new argument + // should be appended to the function arguments list for this result. + // Otherwise, it remains unchanged as a function result. + SmallVector newResultTypes; + newResultTypes.reserve(funcOp.getNumResults()); + for (Type resType : funcType.getResults()) { + Type convertedType = converter->convertType(resType); + if (!allowMemrefFunctionResults && + BufferAssignmentTypeConverter::isConvertedMemref(convertedType, + resType)) + conversion.addInputs(convertedType); + else + newResultTypes.push_back(convertedType); + } + if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, + &conversion))) + return failure(); + + // Update the signature of the function. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + newResultTypes)); + }); + return success(); + } }; /// Rewrites the `ReturnOp` to conform with the changed function signature. -/// Operands that correspond to return values and their types have been set to -/// AppendToArgumentsList are dropped. In their place, a corresponding copy -/// operation from the operand to the target function argument is inserted. +/// if allowMemrefFunctionResults is false, operands that correspond to return +/// values and have been rewritten from illegal typed results to memref +/// arguments are dropped. In their place, a corresponding copy operation from +/// the operand to the output function argument is inserted. Otherwise, the +/// memref typed operands are returned. +/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter, +/// allowMemrefFunctionResults must be set/unset for both. template + typename CopyOpTy, bool allowMemrefFunctionResults> class BufferAssignmentReturnOpConverter : public BufferAssignmentOpConversionPattern { public: @@ -216,48 +174,44 @@ public: LogicalResult matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - Location loc = returnOp.getLoc(); - - // Split the operands depending on whether they need a copy operation or - // they remain as operands of the return operation. If an operand is - // decomposable and a decompose callback function has been provided by the - // user, it will be unpacked. - SmallVector newOperands, needCopyOperands; - OpBuilder builder(returnOp); - for (auto operand : llvm::enumerate(operands)) { - SmallVector values; - this->converter->tryDecomposeValue( - builder, loc, operand.value().getType(), operand.value(), values); - Type type = returnOp.getOperand(operand.index()).getType(); - SmallVector originTypes; - this->converter->tryDecomposeType(type, originTypes); - for (auto value : llvm::enumerate(values)) { - Type origin = originTypes[value.index()]; - Type converted = value.value().getType(); - auto kind = this->converter->getResultConversionKind(origin, converted); - if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) - newOperands.push_back(value.value()); - else - // kind = BufferAssignmentTypeConverter::AppendToArgumentsList - needCopyOperands.push_back(value.value()); - } + // If the memref typed results can be returned as function results, the new + // `ReturnOp` should only return the type converted operands. + if (allowMemrefFunctionResults) { + rewriter.replaceOpWithNewOp(returnOp, operands); + return success(); } - // Insert Copy operations instead for the operands that have been removed - // from operand list and appended to the function arguments list. + // Split the operands by their kinds whether they are converted memref or + // not. + SmallVector needCopyOperands, newOperands; + unsigned operandsSize = operands.size(); + needCopyOperands.reserve(operandsSize); + newOperands.reserve(operandsSize); + for (auto operand : llvm::enumerate(operands)) + if (BufferAssignmentTypeConverter::isConvertedMemref( + operand.value().getType(), + returnOp.getOperand(operand.index()).getType())) + needCopyOperands.push_back(operand.value()); + else + newOperands.push_back(operand.value()); + Block &entryBlock = returnOp.getParentRegion()->front(); unsigned numFuncArgs = entryBlock.getNumArguments(); - if (needCopyOperands.size() > numFuncArgs) - return returnOp.emitError( - "The number of operands that need Copy operations is more " - "than the number of target function arguments."); + + // Find the index of the first destination buffer. + assert(needCopyOperands.size() <= numFuncArgs && + "The number of operands of return operation is more than the " + "number of function arguments."); unsigned destArgNum = numFuncArgs - needCopyOperands.size(); rewriter.setInsertionPoint(returnOp); for (Value operand : needCopyOperands) { - rewriter.create(loc, operand, + // Insert a `CopyOp` for each converted memref-type operand. + rewriter.create(returnOp.getLoc(), operand, entryBlock.getArgument(destArgNum)); ++destArgNum; } + + // Insert the new target Return operation. rewriter.replaceOpWithNewOp(returnOp, newOperands); return success(); } @@ -265,32 +219,94 @@ public: /// Rewrites the `CallOp` to match its operands and results with the signature /// of the callee after rewriting the callee with -/// BufferAssignmentFuncOpConverter. +/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a +/// buffer is allocated as an output buffer only for each memref typed result +/// that has been rewritten. The new allocated buffer is passed through the +/// operands list of the new `CallOp`. +/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter, +/// allowMemrefFunctionResults must be set/unset for both. +template class BufferAssignmentCallOpConverter : public BufferAssignmentOpConversionPattern { public: using BufferAssignmentOpConversionPattern< CallOp>::BufferAssignmentOpConversionPattern; - /// Performs the actual rewriting step. - LogicalResult matchAndRewrite(CallOp, ArrayRef, - ConversionPatternRewriter &) const; + LogicalResult + matchAndRewrite(CallOp callOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + if (!converter) + return callOp.emitError("The type converter has not been defined for " + "BufferAssignmentCallOpConverter"); + Location loc = callOp.getLoc(); + + // If the memref typed results can be returned as function results, there is + // no need to create output buffers. It is only required to convert the type + // of operands and results in place for creating the new `CallOp`. + if (allowMemrefFunctionResults) { + SmallVector resultTypes; + resultTypes.reserve(callOp.getNumResults()); + for (Type type : callOp.getResultTypes()) + resultTypes.push_back(converter->convertType(type)); + rewriter.replaceOpWithNewOp(callOp, callOp.getCallee(), + resultTypes, operands); + return success(); + } + + SmallVector newOperands, replacingValues; + SmallVector newResultTypes; + unsigned numResults = callOp.getNumResults(); + newOperands.reserve(numResults + operands.size()); + newOperands.append(operands.begin(), operands.end()); + newResultTypes.reserve(numResults); + replacingValues.reserve(numResults); + + // For each memref result of `CallOp` which has not been a memref before + // the type conversion, a new buffer is allocated and passed to the operands + // list of the new `CallOp`. Otherwise, it remains as a caller result. + for (Value result : callOp.getResults()) { + Type currType = result.getType(); + Type newType = converter->convertType(result.getType()); + if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition( + result.dyn_cast())); + Value alloc = + rewriter.create(loc, newType.dyn_cast()); + newOperands.push_back(alloc); + replacingValues.push_back(alloc); + } else { + newResultTypes.push_back(currType); + + // No replacing is required. + replacingValues.push_back(nullptr); + } + } + + // Creating the new `CallOp`. + rewriter.create(loc, callOp.getCallee(), newResultTypes, + newOperands); + + // Replacing the results of the old `CallOp`. + rewriter.replaceOp(callOp, replacingValues); + return success(); + } }; +} // end namespace detail /// Populates `patterns` with the conversion patterns of buffer /// assignment. template + typename CopyOpTy, bool allowMemrefFunctionResults> static void populateWithBufferAssignmentOpConversionPatterns( MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, - OwningRewritePatternList *patterns) { + TypeConverter *converter, OwningRewritePatternList *patterns) { // clang-format off patterns->insert< - BufferAssignmentCallOpConverter, - BufferAssignmentFuncOpConverter, - BufferAssignmentReturnOpConverter - + detail::BufferAssignmentCallOpConverter, + detail::BufferAssignmentFuncOpConverter, + detail::BufferAssignmentReturnOpConverter + >(context, placer, converter); // clang-format on } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp index 89a01f9ca629..04c1fbd5d565 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -100,11 +100,11 @@ public: /// tensors to buffers. static void populateConvertLinalgOnTensorsToBuffersPattern( MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, - OwningRewritePatternList *patterns) { + TypeConverter *converter, OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, - converter, patterns); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, + /*allowMemrefFunctionResults=*/false>(context, placer, converter, + patterns); patterns->insert(context, placer, converter); } @@ -141,9 +141,6 @@ struct ConvertLinalgOnTensorsToBuffers converter.isLegal(&funcOp.getBody()); }); - converter.setResultConversionKind( - BufferAssignmentTypeConverter::AppendToArgumentsList); - // Walk over all the functions to apply buffer assignment. getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns; diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp index 1ab3e7e2e48d..201570a244ff 100644 --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -713,223 +713,9 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() { }); } -/// This method tries to decompose a value of a certain type using provided -/// decompose callback functions. If it is unable to do so, the original value -/// is returned. -void BufferAssignmentTypeConverter::tryDecomposeValue( - OpBuilder &builder, Location loc, Type type, Value value, - SmallVectorImpl &results) { - for (auto conversion : decomposeValueConversions) - if (conversion(builder, loc, type, value, results) != llvm::None) - return; - results.push_back(value); -} - -/// This method tries to decompose a type using provided decompose callback -/// functions. If it is unable to do so, the original type is returned. -void BufferAssignmentTypeConverter::tryDecomposeType( - Type type, SmallVectorImpl &types) { - for (auto conversion : decomposeTypeConversions) - if (conversion(type, types) != llvm::None) - return; - types.push_back(type); -} - -/// This method returns ResultConversionKind for the input type. -BufferAssignmentTypeConverter::ResultConversionKind -BufferAssignmentTypeConverter::getResultConversionKind(Type origin, - Type converted) { - for (auto conversion : resultTypeConversions) { - auto res = conversion(origin, converted); - if (res != llvm::None) - return res.getValue(); - } - return KeepAsFunctionResult; -} - -//===----------------------------------------------------------------------===// -// BufferAssignmentFuncOpConverter -//===----------------------------------------------------------------------===// - -/// Performs the actual function signature rewriting step. -LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite( - mlir::FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - auto funcType = funcOp.getType(); - - // Convert function arguments using the provided TypeConverter. - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); - for (auto argType : llvm::enumerate(funcType.getInputs())) { - SmallVector decomposedTypes, convertedTypes; - converter->tryDecomposeType(argType.value(), decomposedTypes); - converter->convertTypes(decomposedTypes, convertedTypes); - conversion.addInputs(argType.index(), convertedTypes); - } - - // Convert the result types of the function. - SmallVector newResultTypes; - newResultTypes.reserve(funcOp.getNumResults()); - for (Type resultType : funcType.getResults()) { - SmallVector originTypes; - converter->tryDecomposeType(resultType, originTypes); - for (auto origin : originTypes) { - Type converted = converter->convertType(origin); - auto kind = converter->getResultConversionKind(origin, converted); - if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList) - conversion.addInputs(converted); - else - // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult - newResultTypes.push_back(converted); - } - } - - if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter, - &conversion))) - return failure(); - - // Update the signature of the function. - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), - newResultTypes)); - }); - return success(); -} - -//===----------------------------------------------------------------------===// -// BufferAssignmentCallOpConverter -//===----------------------------------------------------------------------===// - -/// Performs the actual rewriting step. -LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite( - CallOp callOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - - // This class represents a mapping from a result to a list of values and some - // results that have not yet constructed. Instead, the indices of these - // results in the operation that will be constructed are known. They will be - // replaced with the actual values when they are available. The order of - // adding to this mapping is important. - class ResultMapping { - public: - ResultMapping() { order = 0; }; - - /// Add an available value to the mapping. - void addMapping(Value value) { - toValuesMapping.push_back({order++, value}); - } - - /// Add the index of unavailble result value to the mapping. - void addMapping(unsigned index) { - toIndicesMapping.push_back({order++, index}); - } - - /// This method returns the mapping values list. The unknown result values - /// that only their indicies are available are replaced with their values. - void getMappingValues(ValueRange valuesToReplaceIndices, - SmallVectorImpl &values) { - // Append available values to the list. - SmallVector, 2> res(toValuesMapping.begin(), - toValuesMapping.end()); - // Replace the indices with the actual values. - llvm::for_each( - toIndicesMapping, [&](const std::pair &entry) { - assert(entry.second < valuesToReplaceIndices.size() && - "The value index is out of range."); - res.push_back({entry.first, valuesToReplaceIndices[entry.second]}); - }); - // Sort the values based on their adding orders. - llvm::sort(res, [](const std::pair &v1, - const std::pair &v2) { - return v1.first < v2.first; - }); - // Fill the values. - llvm::for_each(res, [&](const std::pair &entry) { - values.push_back(entry.second); - }); - } - - private: - /// Keeping the inserting order of mapping values. - int order; - - /// Containing the mapping values with their inserting orders. - SmallVector, 2> toValuesMapping; - - /// Containing the indices of result values with their inserting orders. - SmallVector, 2> toIndicesMapping; - }; - - Location loc = callOp.getLoc(); - OpBuilder builder(callOp); - SmallVector newOperands; - - // Create the operands list of the new `CallOp`. It unpacks the decomposable - // values if a decompose callback function has been provided by the user. - for (auto operand : operands) { - SmallVector values; - this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand, - values); - newOperands.append(values.begin(), values.end()); - } - - // Create the new result types for the new `CallOp` and a mapping from the old - // result to new value(s). - SmallVector newResultTypes; - SmallVector mappings; - mappings.resize(callOp.getNumResults()); - for (auto result : llvm::enumerate(callOp.getResults())) { - SmallVector originTypes; - converter->tryDecomposeType(result.value().getType(), originTypes); - auto &resultMapping = mappings[result.index()]; - for (Type origin : originTypes) { - Type converted = converter->convertType(origin); - auto kind = converter->getResultConversionKind(origin, converted); - if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) { - newResultTypes.push_back(converted); - // The result value is not yet available. Its index is kept and it is - // replaced with the actual value of the new `CallOp` later. - resultMapping.addMapping(newResultTypes.size() - 1); - } else { - // kind = BufferAssignmentTypeConverter::AppendToArgumentsList - OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result.value())); - MemRefType memref = converted.dyn_cast(); - if (!memref) - return callOp.emitError("Cannot allocate for a non-Memref type"); - Value alloc = rewriter.create(loc, memref); - newOperands.push_back(alloc); - resultMapping.addMapping(alloc); - } - } - } - - CallOp newCallOp = rewriter.create(loc, callOp.getCallee(), - newResultTypes, newOperands); - - // Build a replacing value for each result to replace its uses. If a result - // has multiple mapping values, it needs to be packed to a single value. - OpBuilder nextBuilder(callOp.getOperation()->getNextNode()); - SmallVector replacedValues; - replacedValues.reserve(callOp.getNumResults()); - for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) { - SmallVector valuesToPack; - mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack); - if (valuesToPack.empty()) { - // No replacement is required. - replacedValues.push_back(nullptr); - } else if (valuesToPack.size() == 1) { - replacedValues.push_back(valuesToPack.front()); - } else { - // Values need to be packed using callback function. The same callback - // that is used for materializeArgumentConversion is used for packing. - Value packed = converter->materializeArgumentConversion( - nextBuilder, loc, callOp.getType(i), valuesToPack); - replacedValues.push_back(packed); - } - } - rewriter.replaceOp(callOp, replacedValues); - return success(); +/// Checks if `type` has been converted from non-memref type to memref. +bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) { + return type.isa() && !before.isa(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir index e1dacdf0184e..084ac38af6e3 100644 --- a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir @@ -111,73 +111,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> { // CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0) // CHECK: return %[[Y]]#0 -// ----- -// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the -// signature of the new signature of the callee function when there are tuple typed -// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed -// arguments. The tuple typed values should be decomposed and composed using -// get_tuple_element and make_tuple operations of test dialect. Tensor types are -// converted to Memref. Memref typed function results remain as function results. -// CHECK-LABEL: func @callee -func @callee(%arg0: tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>){ - return %arg0 : tuple,i1, tensor<5xf32>> -} -// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] -// CHECK-LABEL: func @caller -func @caller(%arg0: tuple,i1, tensor<5xf32>>) -> tuple,i1, tensor<5xf32>>{ - %x0 = call @callee(%arg0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) - %y0 = call @callee(%x0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) - return %y0 : tuple,i1, tensor<5xf32>> -} -// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>) -// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]] -// ----- - -// Test case: Testing BufferAssginmnetFuncOpConverter and -// BufferAssginmentReturnOpConverter to see if the return operation matches with -// the new function signature when there are tuple typed args and results. -// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple -// typed values should be decomposed and composed using get_tuple_element and -// make_tuple operations of test dialect. Tensor types are converted to Memref. -// Memref typed function results remain as function results. - -// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results -func @decompose_tuple_typed_function_args_and_results(%arg0: tuple, %arg1: tensor<10xf32>, %arg2: tuple>) -> (tuple>, tensor<10xf32>, tuple){ - return %arg2, %arg1, %arg0 : tuple>, tensor<10xf32>, tuple -} -// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32> -// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32) -// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) -// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]]) -// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir index b1cfdfd690cf..064b0fd7e85a 100644 --- a/mlir/test/Transforms/buffer-placement-preparation.mlir +++ b/mlir/test/Transforms/buffer-placement-preparation.mlir @@ -285,93 +285,8 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> { // CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]]) // CHECK: return -// ----- - // CHECK-LABEL: func @func_with_unranked_arg func @func_with_unranked_arg(%arg0: tensor<*xf32>) { return } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) - -// ----- - -// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the -// signature of the new signature of the callee function when there are tuple typed -// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed -// arguments. The tuple typed values should be decomposed and composed using -// get_tuple_element and make_tuple operations of test dialect. Tensor types are -// converted to Memref. Memref typed function results are appended to the function -// arguments list. - -// CHECK-LABEL: func @callee -func @callee(%arg0: tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>){ - return %arg0 : tuple,i1, tensor<5xf32>> -} -// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>) -// CHECK-SAME: i1 -// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]]) -// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]]) -// CHECK-NEXT: return %[[SECOND_ELEM]] - - -// CHECK-LABEL: func @caller -func @caller(%arg0: tuple,i1, tensor<5xf32>>) -> tuple,i1, tensor<5xf32>>{ - %x0 = call @callee(%arg0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) - %y0 = call @callee(%x0) : (tuple,i1, tensor<5xf32>>) -> (tuple,i1, tensor<5xf32>>) - return %y0 : tuple,i1, tensor<5xf32>> -} -// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>) -// CHECK-SAME: i1 -// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1 -// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc() -// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc() -// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]) -// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1 -// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]]) -// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32} -// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]]) -// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]]) -// CHECK-NEXT: return %[[SECOND_ELEM]] - -// ----- - -// Test case: Testing BufferAssginmnetFuncOpConverter and -// BufferAssginmentReturnOpConverter to see if the return operation matches with -// the new function signature when there are tuple typed args and results. -// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple -// typed values should be decomposed and composed using get_tuple_element and -// make_tuple operations of test dialect. Tensor types are converted to Memref. -// Memref typed function results are appended to the function arguments list. - -// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results -func @decompose_tuple_typed_function_args_and_results(%arg0: tuple, %arg1: tensor<10xf32>, %arg2: tuple>) -> (tuple>, tensor<10xf32>, tuple){ - return %arg2, %arg1, %arg0 : tuple>, tensor<10xf32>, tuple -} -// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32> -// CHECK-SAME: (i1, i1, f32) -// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) -// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]]) -// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32} -// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32} -// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]]) -// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]]) -// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]] diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index f03c953396a4..bc26a8659831 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5", let results = (outs AnyType:$result); let extraClassDeclaration = [{ - static LogicalResult inferReturnTypes(MLIRContext *, + static LogicalResult inferReturnTypes(MLIRContext *, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { @@ -1679,31 +1679,4 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5", }]; } -//===----------------------------------------------------------------------===// -// Test BufferPlacement -//===----------------------------------------------------------------------===// - -def GetTupleElementOp: TEST_Op<"get_tuple_element"> { - let description = [{ - Test op that returns a specified element of the tuple. - }]; - - let arguments = (ins - TupleOf<[AnyType]>, - I32Attr:$index - ); - let results = (outs AnyType); -} - -def MakeTupleOp: TEST_Op<"make_tuple"> { - let description = [{ - Test op that creates a tuple value from a list of values. - }]; - - let arguments = (ins - Variadic:$inputs - ); - let results = (outs TupleOf<[AnyType]>); -} - #endif // TEST_OPS diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp index 14b72b9fc92a..6cc0924191cb 100644 --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -11,8 +11,6 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" @@ -111,16 +109,14 @@ struct TestBufferPlacementPreparationPass void populateTensorLinalgToBufferLinalgConversionPattern( MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, - OwningRewritePatternList *patterns) { + TypeConverter *converter, OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, - converter, patterns); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp, + allowMemrefFunctionResults>(context, placer, converter, patterns); patterns->insert(context, placer, converter); } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); registry.insert(); } @@ -131,8 +127,6 @@ struct TestBufferPlacementPreparationPass // Mark all Standard operations legal. target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -155,42 +149,6 @@ struct TestBufferPlacementPreparationPass converter.isLegal(&funcOp.getBody()); }); - auto kind = allowMemrefFunctionResults - ? BufferAssignmentTypeConverter::KeepAsFunctionResult - : BufferAssignmentTypeConverter::AppendToArgumentsList; - converter.setResultConversionKind(kind); - converter.setResultConversionKind( - kind); - - converter.addDecomposeTypeConversion( - [](TupleType tupleType, SmallVectorImpl &types) { - tupleType.getFlattenedTypes(types); - return success(); - }); - - converter.addArgumentMaterialization( - [](OpBuilder &builder, TupleType resultType, ValueRange inputs, - Location loc) -> Optional { - if (inputs.size() == 1) - return llvm::None; - TypeRange TypeRange = inputs.getTypes(); - SmallVector types(TypeRange.begin(), TypeRange.end()); - TupleType tuple = TupleType::get(types, builder.getContext()); - mlir::Value value = builder.create(loc, tuple, inputs); - return value; - }); - - converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc, - TupleType resultType, Value value, - SmallVectorImpl &values) { - for (unsigned i = 0, e = resultType.size(); i < e; ++i) { - Value res = builder.create( - loc, resultType.getType(i), value, builder.getI32IntegerAttr(i)); - values.push_back(res); - } - return success(); - }); - // Walk over all the functions to apply buffer assignment. this->getOperation().walk([&](FuncOp function) -> WalkResult { OwningRewritePatternList patterns;