Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks"

This reverts commit 94f5d24877 because
of failing the following tests:

MLIR :: Dialect/Linalg/tensors-to-buffers.mlir
MLIR :: Transforms/buffer-placement-preparation-allowed-memref-results.mlir
MLIR :: Transforms/buffer-placement-preparation.mlir
This commit is contained in:
Lei Zhang 2020-09-02 09:24:36 -04:00
parent 6d36b22b21
commit 1b88bbf5eb
7 changed files with 191 additions and 612 deletions

View File

@ -52,111 +52,6 @@ private:
Operation *operation; Operation *operation;
}; };
/// A helper type converter class for using inside Buffer Assignment operation
/// conversion patterns. The default constructor keeps all the types intact
/// except for the ranked-tensor types which is converted to memref types.
class BufferAssignmentTypeConverter : public TypeConverter {
public:
/// This enum is for showing how buffer placement operation converters should
/// conduct with certain result type after type conversion. This value can be
/// set/get for each specific type using setResultConversionKind or
/// getResultConversionKind.
enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult };
BufferAssignmentTypeConverter();
/// This method tries to decompose a value of a certain type using provided
/// decompose callback functions. If it is unable to do so, the original value
/// is returned.
void tryDecomposeValue(OpBuilder &, Location, Type, Value,
SmallVectorImpl<Value> &);
/// This method tries to decompose a type using provided decompose callback
/// functions. If it is unable to do so, the original type is returned.
void tryDecomposeType(Type, SmallVectorImpl<Type> &);
/// This method registers a callback function that will be called to decompose
/// a value of a certain type into several values.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
void addDecomposeValueConversion(FnT &&callback) {
decomposeValueConversions.emplace_back(
wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
}
/// This method registers a callback function that will be called to decompose
/// a type into several types.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
void addDecomposeTypeConversion(FnT &&callback) {
auto wrapper =
wrapDecomposeTypeConversionCallback<T>(std::forward<FnT>(callback));
decomposeTypeConversions.emplace_back(wrapper);
addConversion(std::forward<FnT>(callback));
}
/// This method returns ResultConversionKind for the mapping from `origin`
/// type to `input` type.
ResultConversionKind getResultConversionKind(Type origin, Type input);
/// This method registers ResultConversionKind for the mapping from type 'T'
/// to type 'U'.
template <typename T, typename U>
void setResultConversionKind(ResultConversionKind kind) {
assert((kind != AppendToArgumentsList ||
llvm::is_one_of<U, MemRefType, UnrankedMemRefType>::value) &&
"Only the memref typed values can be set to be appended to the "
"function argument list at the moment");
resultTypeConversions.emplace_back(
[&](Type origin, Type input) -> Optional<ResultConversionKind> {
if (origin.template isa<T>() && input.template isa<U>())
return kind;
return llvm::None;
});
}
private:
using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
using DecomposeTypeConversionCallFn =
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
using ResultConversionKindFn =
std::function<Optional<ResultConversionKind>(Type, Type)>;
/// Generate a wrapper for the given decompose value conversion callback.
template <typename T, typename FnT>
DecomposeValueConversionCallFn
wrapDecomposeValueConversionCallback(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
OpBuilder &builder, Location loc, Type type, Value value,
SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
if (T derivedType = type.dyn_cast<T>())
return callback(builder, loc, derivedType, value, newValues);
return llvm::None;
};
}
/// Generate a wrapper for the given decompose type conversion callback.
template <typename T, typename FnT>
DecomposeTypeConversionCallFn
wrapDecomposeTypeConversionCallback(FnT &&callback) {
return [callback = std::forward<FnT>(callback)](
Type type,
SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
T derivedType = type.dyn_cast<T>();
if (!derivedType)
return llvm::None;
return callback(derivedType, results);
};
}
SmallVector<ResultConversionKindFn, 2> resultTypeConversions;
SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
};
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer /// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
/// instance. Sample usage: /// instance. Sample usage:
/// class CustomConversionPattern : public /// class CustomConversionPattern : public
@ -173,22 +68,43 @@ class BufferAssignmentOpConversionPattern
public: public:
explicit BufferAssignmentOpConversionPattern( explicit BufferAssignmentOpConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr, MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
BufferAssignmentTypeConverter *converter = nullptr, TypeConverter *converter = nullptr, PatternBenefit benefit = 1)
PatternBenefit benefit = 1)
: OpConversionPattern<SourceOp>(context, benefit), : OpConversionPattern<SourceOp>(context, benefit),
bufferAssignment(bufferAssignment), converter(converter) { bufferAssignment(bufferAssignment), converter(converter) {}
assert(converter && "The type converter has not been defined");
}
protected: protected:
BufferAssignmentPlacer *bufferAssignment; BufferAssignmentPlacer *bufferAssignment;
BufferAssignmentTypeConverter *converter; TypeConverter *converter;
}; };
/// Converts the signature of the function using BufferAssignmentTypeConverter. /// A helper type converter class for using inside Buffer Assignment operation
/// Each result type of the function is kept as a function result or appended to /// conversion patterns. The default constructor keeps all the types intact
/// the function arguments list based on ResultConversionKind for the converted /// except for the ranked-tensor types which is converted to memref types.
/// result type. class BufferAssignmentTypeConverter : public TypeConverter {
public:
BufferAssignmentTypeConverter();
/// A helper function to check if `type` has been converted from non-memref
/// type to memref.
static bool isConvertedMemref(Type type, Type before);
};
namespace detail {
/// Converts the signature of the function based on whether the function is
/// allowed to return memref typed results or not using
/// `allowMemrefFunctionResults` parameter. If this option is false, then it
/// adds an extra function argument as an output buffer for each function result
/// which is going to be a memref type only after type conversion. The
/// other function result types remain unchanged. If
/// `allowMemrefFunctionResults` is true, the types are converted in place.
/// Any changes in function signature need to be applied
/// to return and caller operations. `BufferAssignmentReturnOpConverter` and
/// `BufferAssignmentCallOpConverter` are two helper function that match the
/// return and caller operation with the new function signature. Furthermore,
/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting
/// tensor typed values to memref typed ones.
template <bool allowMemrefFunctionResults>
class BufferAssignmentFuncOpConverter class BufferAssignmentFuncOpConverter
: public BufferAssignmentOpConversionPattern<FuncOp> { : public BufferAssignmentOpConversionPattern<FuncOp> {
public: public:
@ -196,16 +112,58 @@ public:
FuncOp>::BufferAssignmentOpConversionPattern; FuncOp>::BufferAssignmentOpConversionPattern;
/// Performs the actual signature rewriting step. /// Performs the actual signature rewriting step.
LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef<Value>, LogicalResult
ConversionPatternRewriter &) const; matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
if (!converter)
return funcOp.emitError("The type converter has not been defined for "
"BufferAssignmentFuncOpConverter");
auto funcType = funcOp.getType();
// Convert function arguments using the provided TypeConverter.
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
for (auto argType : llvm::enumerate(funcType.getInputs()))
conversion.addInputs(argType.index(),
converter->convertType(argType.value()));
// If allowMemrefFunctionResults is false and a function result type is not
// a memref but it would be a memref after type conversion, a new argument
// should be appended to the function arguments list for this result.
// Otherwise, it remains unchanged as a function result.
SmallVector<Type, 2> newResultTypes;
newResultTypes.reserve(funcOp.getNumResults());
for (Type resType : funcType.getResults()) {
Type convertedType = converter->convertType(resType);
if (!allowMemrefFunctionResults &&
BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
resType))
conversion.addInputs(convertedType);
else
newResultTypes.push_back(convertedType);
}
if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
&conversion)))
return failure();
// Update the signature of the function.
rewriter.updateRootInPlace(funcOp, [&] {
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
newResultTypes));
});
return success();
}
}; };
/// Rewrites the `ReturnOp` to conform with the changed function signature. /// Rewrites the `ReturnOp` to conform with the changed function signature.
/// Operands that correspond to return values and their types have been set to /// if allowMemrefFunctionResults is false, operands that correspond to return
/// AppendToArgumentsList are dropped. In their place, a corresponding copy /// values and have been rewritten from illegal typed results to memref
/// operation from the operand to the target function argument is inserted. /// arguments are dropped. In their place, a corresponding copy operation from
/// the operand to the output function argument is inserted. Otherwise, the
/// memref typed operands are returned.
/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
/// allowMemrefFunctionResults must be set/unset for both.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy, template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
typename CopyOpTy> typename CopyOpTy, bool allowMemrefFunctionResults>
class BufferAssignmentReturnOpConverter class BufferAssignmentReturnOpConverter
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> { : public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
public: public:
@ -216,48 +174,44 @@ public:
LogicalResult LogicalResult
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands, matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
Location loc = returnOp.getLoc(); // If the memref typed results can be returned as function results, the new
// `ReturnOp` should only return the type converted operands.
// Split the operands depending on whether they need a copy operation or if (allowMemrefFunctionResults) {
// they remain as operands of the return operation. If an operand is rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
// decomposable and a decompose callback function has been provided by the return success();
// user, it will be unpacked.
SmallVector<Value, 2> newOperands, needCopyOperands;
OpBuilder builder(returnOp);
for (auto operand : llvm::enumerate(operands)) {
SmallVector<Value, 2> values;
this->converter->tryDecomposeValue(
builder, loc, operand.value().getType(), operand.value(), values);
Type type = returnOp.getOperand(operand.index()).getType();
SmallVector<Type, 2> originTypes;
this->converter->tryDecomposeType(type, originTypes);
for (auto value : llvm::enumerate(values)) {
Type origin = originTypes[value.index()];
Type converted = value.value().getType();
auto kind = this->converter->getResultConversionKind(origin, converted);
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
newOperands.push_back(value.value());
else
// kind = BufferAssignmentTypeConverter::AppendToArgumentsList
needCopyOperands.push_back(value.value());
}
} }
// Insert Copy operations instead for the operands that have been removed // Split the operands by their kinds whether they are converted memref or
// from operand list and appended to the function arguments list. // not.
SmallVector<Value, 2> needCopyOperands, newOperands;
unsigned operandsSize = operands.size();
needCopyOperands.reserve(operandsSize);
newOperands.reserve(operandsSize);
for (auto operand : llvm::enumerate(operands))
if (BufferAssignmentTypeConverter::isConvertedMemref(
operand.value().getType(),
returnOp.getOperand(operand.index()).getType()))
needCopyOperands.push_back(operand.value());
else
newOperands.push_back(operand.value());
Block &entryBlock = returnOp.getParentRegion()->front(); Block &entryBlock = returnOp.getParentRegion()->front();
unsigned numFuncArgs = entryBlock.getNumArguments(); unsigned numFuncArgs = entryBlock.getNumArguments();
if (needCopyOperands.size() > numFuncArgs)
return returnOp.emitError( // Find the index of the first destination buffer.
"The number of operands that need Copy operations is more " assert(needCopyOperands.size() <= numFuncArgs &&
"than the number of target function arguments."); "The number of operands of return operation is more than the "
"number of function arguments.");
unsigned destArgNum = numFuncArgs - needCopyOperands.size(); unsigned destArgNum = numFuncArgs - needCopyOperands.size();
rewriter.setInsertionPoint(returnOp); rewriter.setInsertionPoint(returnOp);
for (Value operand : needCopyOperands) { for (Value operand : needCopyOperands) {
rewriter.create<CopyOpTy>(loc, operand, // Insert a `CopyOp` for each converted memref-type operand.
rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
entryBlock.getArgument(destArgNum)); entryBlock.getArgument(destArgNum));
++destArgNum; ++destArgNum;
} }
// Insert the new target Return operation.
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands); rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
return success(); return success();
} }
@ -265,32 +219,94 @@ public:
/// Rewrites the `CallOp` to match its operands and results with the signature /// Rewrites the `CallOp` to match its operands and results with the signature
/// of the callee after rewriting the callee with /// of the callee after rewriting the callee with
/// BufferAssignmentFuncOpConverter. /// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a
/// buffer is allocated as an output buffer only for each memref typed result
/// that has been rewritten. The new allocated buffer is passed through the
/// operands list of the new `CallOp`.
/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
/// allowMemrefFunctionResults must be set/unset for both.
template <bool allowMemrefFunctionResults>
class BufferAssignmentCallOpConverter class BufferAssignmentCallOpConverter
: public BufferAssignmentOpConversionPattern<CallOp> { : public BufferAssignmentOpConversionPattern<CallOp> {
public: public:
using BufferAssignmentOpConversionPattern< using BufferAssignmentOpConversionPattern<
CallOp>::BufferAssignmentOpConversionPattern; CallOp>::BufferAssignmentOpConversionPattern;
/// Performs the actual rewriting step. LogicalResult
LogicalResult matchAndRewrite(CallOp, ArrayRef<Value>, matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
ConversionPatternRewriter &) const; ConversionPatternRewriter &rewriter) const final {
if (!converter)
return callOp.emitError("The type converter has not been defined for "
"BufferAssignmentCallOpConverter");
Location loc = callOp.getLoc();
// If the memref typed results can be returned as function results, there is
// no need to create output buffers. It is only required to convert the type
// of operands and results in place for creating the new `CallOp`.
if (allowMemrefFunctionResults) {
SmallVector<Type, 2> resultTypes;
resultTypes.reserve(callOp.getNumResults());
for (Type type : callOp.getResultTypes())
resultTypes.push_back(converter->convertType(type));
rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
resultTypes, operands);
return success();
}
SmallVector<Value, 2> newOperands, replacingValues;
SmallVector<Type, 2> newResultTypes;
unsigned numResults = callOp.getNumResults();
newOperands.reserve(numResults + operands.size());
newOperands.append(operands.begin(), operands.end());
newResultTypes.reserve(numResults);
replacingValues.reserve(numResults);
// For each memref result of `CallOp` which has not been a memref before
// the type conversion, a new buffer is allocated and passed to the operands
// list of the new `CallOp`. Otherwise, it remains as a caller result.
for (Value result : callOp.getResults()) {
Type currType = result.getType();
Type newType = converter->convertType(result.getType());
if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
result.dyn_cast<OpResult>()));
Value alloc =
rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
newOperands.push_back(alloc);
replacingValues.push_back(alloc);
} else {
newResultTypes.push_back(currType);
// No replacing is required.
replacingValues.push_back(nullptr);
}
}
// Creating the new `CallOp`.
rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
newOperands);
// Replacing the results of the old `CallOp`.
rewriter.replaceOp(callOp, replacingValues);
return success();
}
}; };
} // end namespace detail
/// Populates `patterns` with the conversion patterns of buffer /// Populates `patterns` with the conversion patterns of buffer
/// assignment. /// assignment.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy, template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
typename CopyOpTy> typename CopyOpTy, bool allowMemrefFunctionResults>
static void populateWithBufferAssignmentOpConversionPatterns( static void populateWithBufferAssignmentOpConversionPatterns(
MLIRContext *context, BufferAssignmentPlacer *placer, MLIRContext *context, BufferAssignmentPlacer *placer,
BufferAssignmentTypeConverter *converter, TypeConverter *converter, OwningRewritePatternList *patterns) {
OwningRewritePatternList *patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
BufferAssignmentCallOpConverter, detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
BufferAssignmentFuncOpConverter, detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
BufferAssignmentReturnOpConverter detail::BufferAssignmentReturnOpConverter
<ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy> <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
>(context, placer, converter); >(context, placer, converter);
// clang-format on // clang-format on
} }

View File

@ -100,11 +100,11 @@ public:
/// tensors to buffers. /// tensors to buffers.
static void populateConvertLinalgOnTensorsToBuffersPattern( static void populateConvertLinalgOnTensorsToBuffersPattern(
MLIRContext *context, BufferAssignmentPlacer *placer, MLIRContext *context, BufferAssignmentPlacer *placer,
BufferAssignmentTypeConverter *converter, TypeConverter *converter, OwningRewritePatternList *patterns) {
OwningRewritePatternList *patterns) {
populateWithBufferAssignmentOpConversionPatterns< populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
converter, patterns); /*allowMemrefFunctionResults=*/false>(context, placer, converter,
patterns);
patterns->insert<GenericOpConverter>(context, placer, converter); patterns->insert<GenericOpConverter>(context, placer, converter);
} }
@ -141,9 +141,6 @@ struct ConvertLinalgOnTensorsToBuffers
converter.isLegal(&funcOp.getBody()); converter.isLegal(&funcOp.getBody());
}); });
converter.setResultConversionKind<RankedTensorType, MemRefType>(
BufferAssignmentTypeConverter::AppendToArgumentsList);
// Walk over all the functions to apply buffer assignment. // Walk over all the functions to apply buffer assignment.
getOperation().walk([&](FuncOp function) -> WalkResult { getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;

View File

@ -713,223 +713,9 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
}); });
} }
/// This method tries to decompose a value of a certain type using provided /// Checks if `type` has been converted from non-memref type to memref.
/// decompose callback functions. If it is unable to do so, the original value bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
/// is returned. return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
void BufferAssignmentTypeConverter::tryDecomposeValue(
OpBuilder &builder, Location loc, Type type, Value value,
SmallVectorImpl<Value> &results) {
for (auto conversion : decomposeValueConversions)
if (conversion(builder, loc, type, value, results) != llvm::None)
return;
results.push_back(value);
}
/// This method tries to decompose a type using provided decompose callback
/// functions. If it is unable to do so, the original type is returned.
void BufferAssignmentTypeConverter::tryDecomposeType(
Type type, SmallVectorImpl<Type> &types) {
for (auto conversion : decomposeTypeConversions)
if (conversion(type, types) != llvm::None)
return;
types.push_back(type);
}
/// This method returns ResultConversionKind for the input type.
BufferAssignmentTypeConverter::ResultConversionKind
BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
Type converted) {
for (auto conversion : resultTypeConversions) {
auto res = conversion(origin, converted);
if (res != llvm::None)
return res.getValue();
}
return KeepAsFunctionResult;
}
//===----------------------------------------------------------------------===//
// BufferAssignmentFuncOpConverter
//===----------------------------------------------------------------------===//
/// Performs the actual function signature rewriting step.
LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
mlir::FuncOp funcOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
auto funcType = funcOp.getType();
// Convert function arguments using the provided TypeConverter.
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
for (auto argType : llvm::enumerate(funcType.getInputs())) {
SmallVector<Type, 2> decomposedTypes, convertedTypes;
converter->tryDecomposeType(argType.value(), decomposedTypes);
converter->convertTypes(decomposedTypes, convertedTypes);
conversion.addInputs(argType.index(), convertedTypes);
}
// Convert the result types of the function.
SmallVector<Type, 2> newResultTypes;
newResultTypes.reserve(funcOp.getNumResults());
for (Type resultType : funcType.getResults()) {
SmallVector<Type, 2> originTypes;
converter->tryDecomposeType(resultType, originTypes);
for (auto origin : originTypes) {
Type converted = converter->convertType(origin);
auto kind = converter->getResultConversionKind(origin, converted);
if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
conversion.addInputs(converted);
else
// kind = BufferAssignmentTypeConverter::KeepAsFunctionResult
newResultTypes.push_back(converted);
}
}
if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
&conversion)))
return failure();
// Update the signature of the function.
rewriter.updateRootInPlace(funcOp, [&] {
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
newResultTypes));
});
return success();
}
//===----------------------------------------------------------------------===//
// BufferAssignmentCallOpConverter
//===----------------------------------------------------------------------===//
/// Performs the actual rewriting step.
LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
CallOp callOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// This class represents a mapping from a result to a list of values and some
// results that have not yet constructed. Instead, the indices of these
// results in the operation that will be constructed are known. They will be
// replaced with the actual values when they are available. The order of
// adding to this mapping is important.
class ResultMapping {
public:
ResultMapping() { order = 0; };
/// Add an available value to the mapping.
void addMapping(Value value) {
toValuesMapping.push_back({order++, value});
}
/// Add the index of unavailble result value to the mapping.
void addMapping(unsigned index) {
toIndicesMapping.push_back({order++, index});
}
/// This method returns the mapping values list. The unknown result values
/// that only their indicies are available are replaced with their values.
void getMappingValues(ValueRange valuesToReplaceIndices,
SmallVectorImpl<Value> &values) {
// Append available values to the list.
SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
toValuesMapping.end());
// Replace the indices with the actual values.
llvm::for_each(
toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
assert(entry.second < valuesToReplaceIndices.size() &&
"The value index is out of range.");
res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
});
// Sort the values based on their adding orders.
llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
const std::pair<unsigned, Value> &v2) {
return v1.first < v2.first;
});
// Fill the values.
llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
values.push_back(entry.second);
});
}
private:
/// Keeping the inserting order of mapping values.
int order;
/// Containing the mapping values with their inserting orders.
SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
/// Containing the indices of result values with their inserting orders.
SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
};
Location loc = callOp.getLoc();
OpBuilder builder(callOp);
SmallVector<Value, 2> newOperands;
// Create the operands list of the new `CallOp`. It unpacks the decomposable
// values if a decompose callback function has been provided by the user.
for (auto operand : operands) {
SmallVector<Value, 2> values;
this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
values);
newOperands.append(values.begin(), values.end());
}
// Create the new result types for the new `CallOp` and a mapping from the old
// result to new value(s).
SmallVector<Type, 2> newResultTypes;
SmallVector<ResultMapping, 4> mappings;
mappings.resize(callOp.getNumResults());
for (auto result : llvm::enumerate(callOp.getResults())) {
SmallVector<Type, 2> originTypes;
converter->tryDecomposeType(result.value().getType(), originTypes);
auto &resultMapping = mappings[result.index()];
for (Type origin : originTypes) {
Type converted = converter->convertType(origin);
auto kind = converter->getResultConversionKind(origin, converted);
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
newResultTypes.push_back(converted);
// The result value is not yet available. Its index is kept and it is
// replaced with the actual value of the new `CallOp` later.
resultMapping.addMapping(newResultTypes.size() - 1);
} else {
// kind = BufferAssignmentTypeConverter::AppendToArgumentsList
OpBuilder::InsertionGuard guard(rewriter);
rewriter.restoreInsertionPoint(
bufferAssignment->computeAllocPosition(result.value()));
MemRefType memref = converted.dyn_cast<MemRefType>();
if (!memref)
return callOp.emitError("Cannot allocate for a non-Memref type");
Value alloc = rewriter.create<AllocOp>(loc, memref);
newOperands.push_back(alloc);
resultMapping.addMapping(alloc);
}
}
}
CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
newResultTypes, newOperands);
// Build a replacing value for each result to replace its uses. If a result
// has multiple mapping values, it needs to be packed to a single value.
OpBuilder nextBuilder(callOp.getOperation()->getNextNode());
SmallVector<Value, 2> replacedValues;
replacedValues.reserve(callOp.getNumResults());
for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
SmallVector<Value, 2> valuesToPack;
mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
if (valuesToPack.empty()) {
// No replacement is required.
replacedValues.push_back(nullptr);
} else if (valuesToPack.size() == 1) {
replacedValues.push_back(valuesToPack.front());
} else {
// Values need to be packed using callback function. The same callback
// that is used for materializeArgumentConversion is used for packing.
Value packed = converter->materializeArgumentConversion(
nextBuilder, loc, callOp.getType(i), valuesToPack);
replacedValues.push_back(packed);
}
}
rewriter.replaceOp(callOp, replacedValues);
return success();
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -111,73 +111,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0) // CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
// CHECK: return %[[Y]]#0 // CHECK: return %[[Y]]#0
// -----
// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
// signature of the new signature of the callee function when there are tuple typed
// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
// arguments. The tuple typed values should be decomposed and composed using
// get_tuple_element and make_tuple operations of test dialect. Tensor types are
// converted to Memref. Memref typed function results remain as function results.
// CHECK-LABEL: func @callee
func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
}
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
// CHECK-LABEL: func @caller
func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
%x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
%y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
}
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
// -----
// Test case: Testing BufferAssginmnetFuncOpConverter and
// BufferAssginmentReturnOpConverter to see if the return operation matches with
// the new function signature when there are tuple typed args and results.
// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
// typed values should be decomposed and composed using get_tuple_element and
// make_tuple operations of test dialect. Tensor types are converted to Memref.
// Memref typed function results remain as function results.
// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
}
// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>
// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32)
// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]

View File

@ -285,93 +285,8 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]]) // CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
// CHECK: return // CHECK: return
// -----
// CHECK-LABEL: func @func_with_unranked_arg // CHECK-LABEL: func @func_with_unranked_arg
func @func_with_unranked_arg(%arg0: tensor<*xf32>) { func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
return return
} }
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
// -----
// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
// signature of the new signature of the callee function when there are tuple typed
// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
// arguments. The tuple typed values should be decomposed and composed using
// get_tuple_element and make_tuple operations of test dialect. Tensor types are
// converted to Memref. Memref typed function results are appended to the function
// arguments list.
// CHECK-LABEL: func @callee
func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
}
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
// CHECK-SAME: i1
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
// CHECK-NEXT: return %[[SECOND_ELEM]]
// CHECK-LABEL: func @caller
func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
%x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
%y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
}
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
// CHECK-SAME: i1
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
// CHECK-NEXT: return %[[SECOND_ELEM]]
// -----
// Test case: Testing BufferAssginmnetFuncOpConverter and
// BufferAssginmentReturnOpConverter to see if the return operation matches with
// the new function signature when there are tuple typed args and results.
// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
// typed values should be decomposed and composed using get_tuple_element and
// make_tuple operations of test dialect. Tensor types are converted to Memref.
// Memref typed function results are appended to the function arguments list.
// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
}
// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32>
// CHECK-SAME: (i1, i1, f32)
// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]])
// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]

View File

@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
let results = (outs AnyType:$result); let results = (outs AnyType:$result);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static LogicalResult inferReturnTypes(MLIRContext *, static LogicalResult inferReturnTypes(MLIRContext *,
Optional<Location> location, ValueRange operands, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) { SmallVectorImpl<Type> &inferredReturnTypes) {
@ -1679,31 +1679,4 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
}]; }];
} }
//===----------------------------------------------------------------------===//
// Test BufferPlacement
//===----------------------------------------------------------------------===//
def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
let description = [{
Test op that returns a specified element of the tuple.
}];
let arguments = (ins
TupleOf<[AnyType]>,
I32Attr:$index
);
let results = (outs AnyType);
}
def MakeTupleOp: TEST_Op<"make_tuple"> {
let description = [{
Test op that creates a tuple value from a list of values.
}];
let arguments = (ins
Variadic<AnyType>:$inputs
);
let results = (outs TupleOf<[AnyType]>);
}
#endif // TEST_OPS #endif // TEST_OPS

View File

@ -11,8 +11,6 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
@ -111,16 +109,14 @@ struct TestBufferPlacementPreparationPass
void populateTensorLinalgToBufferLinalgConversionPattern( void populateTensorLinalgToBufferLinalgConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *placer, MLIRContext *context, BufferAssignmentPlacer *placer,
BufferAssignmentTypeConverter *converter, TypeConverter *converter, OwningRewritePatternList *patterns) {
OwningRewritePatternList *patterns) {
populateWithBufferAssignmentOpConversionPatterns< populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
converter, patterns); allowMemrefFunctionResults>(context, placer, converter, patterns);
patterns->insert<GenericOpConverter>(context, placer, converter); patterns->insert<GenericOpConverter>(context, placer, converter);
} }
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TestDialect>();
registry.insert<linalg::LinalgDialect>(); registry.insert<linalg::LinalgDialect>();
} }
@ -131,8 +127,6 @@ struct TestBufferPlacementPreparationPass
// Mark all Standard operations legal. // Mark all Standard operations legal.
target.addLegalDialect<StandardOpsDialect>(); target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<MakeTupleOp>();
target.addLegalOp<GetTupleElementOp>();
// Mark all Linalg operations illegal as long as they work on tensors. // Mark all Linalg operations illegal as long as they work on tensors.
auto isLegalOperation = [&](Operation *op) { auto isLegalOperation = [&](Operation *op) {
@ -155,42 +149,6 @@ struct TestBufferPlacementPreparationPass
converter.isLegal(&funcOp.getBody()); converter.isLegal(&funcOp.getBody());
}); });
auto kind = allowMemrefFunctionResults
? BufferAssignmentTypeConverter::KeepAsFunctionResult
: BufferAssignmentTypeConverter::AppendToArgumentsList;
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
kind);
converter.addDecomposeTypeConversion(
[](TupleType tupleType, SmallVectorImpl<Type> &types) {
tupleType.getFlattenedTypes(types);
return success();
});
converter.addArgumentMaterialization(
[](OpBuilder &builder, TupleType resultType, ValueRange inputs,
Location loc) -> Optional<Value> {
if (inputs.size() == 1)
return llvm::None;
TypeRange TypeRange = inputs.getTypes();
SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
TupleType tuple = TupleType::get(types, builder.getContext());
mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs);
return value;
});
converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
TupleType resultType, Value value,
SmallVectorImpl<Value> &values) {
for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
Value res = builder.create<GetTupleElementOp>(
loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
values.push_back(res);
}
return success();
});
// Walk over all the functions to apply buffer assignment. // Walk over all the functions to apply buffer assignment.
this->getOperation().walk([&](FuncOp function) -> WalkResult { this->getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;