forked from OSchip/llvm-project
Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks"
This reverts commit 94f5d24877
because
of failing the following tests:
MLIR :: Dialect/Linalg/tensors-to-buffers.mlir
MLIR :: Transforms/buffer-placement-preparation-allowed-memref-results.mlir
MLIR :: Transforms/buffer-placement-preparation.mlir
This commit is contained in:
parent
6d36b22b21
commit
1b88bbf5eb
|
@ -52,111 +52,6 @@ private:
|
|||
Operation *operation;
|
||||
};
|
||||
|
||||
/// A helper type converter class for using inside Buffer Assignment operation
|
||||
/// conversion patterns. The default constructor keeps all the types intact
|
||||
/// except for the ranked-tensor types which is converted to memref types.
|
||||
class BufferAssignmentTypeConverter : public TypeConverter {
|
||||
public:
|
||||
/// This enum is for showing how buffer placement operation converters should
|
||||
/// conduct with certain result type after type conversion. This value can be
|
||||
/// set/get for each specific type using setResultConversionKind or
|
||||
/// getResultConversionKind.
|
||||
enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult };
|
||||
|
||||
BufferAssignmentTypeConverter();
|
||||
|
||||
/// This method tries to decompose a value of a certain type using provided
|
||||
/// decompose callback functions. If it is unable to do so, the original value
|
||||
/// is returned.
|
||||
void tryDecomposeValue(OpBuilder &, Location, Type, Value,
|
||||
SmallVectorImpl<Value> &);
|
||||
|
||||
/// This method tries to decompose a type using provided decompose callback
|
||||
/// functions. If it is unable to do so, the original type is returned.
|
||||
void tryDecomposeType(Type, SmallVectorImpl<Type> &);
|
||||
|
||||
/// This method registers a callback function that will be called to decompose
|
||||
/// a value of a certain type into several values.
|
||||
template <typename FnT,
|
||||
typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
|
||||
void addDecomposeValueConversion(FnT &&callback) {
|
||||
decomposeValueConversions.emplace_back(
|
||||
wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
|
||||
}
|
||||
|
||||
/// This method registers a callback function that will be called to decompose
|
||||
/// a type into several types.
|
||||
template <typename FnT,
|
||||
typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
|
||||
void addDecomposeTypeConversion(FnT &&callback) {
|
||||
auto wrapper =
|
||||
wrapDecomposeTypeConversionCallback<T>(std::forward<FnT>(callback));
|
||||
decomposeTypeConversions.emplace_back(wrapper);
|
||||
addConversion(std::forward<FnT>(callback));
|
||||
}
|
||||
|
||||
/// This method returns ResultConversionKind for the mapping from `origin`
|
||||
/// type to `input` type.
|
||||
ResultConversionKind getResultConversionKind(Type origin, Type input);
|
||||
|
||||
/// This method registers ResultConversionKind for the mapping from type 'T'
|
||||
/// to type 'U'.
|
||||
template <typename T, typename U>
|
||||
void setResultConversionKind(ResultConversionKind kind) {
|
||||
assert((kind != AppendToArgumentsList ||
|
||||
llvm::is_one_of<U, MemRefType, UnrankedMemRefType>::value) &&
|
||||
"Only the memref typed values can be set to be appended to the "
|
||||
"function argument list at the moment");
|
||||
resultTypeConversions.emplace_back(
|
||||
[&](Type origin, Type input) -> Optional<ResultConversionKind> {
|
||||
if (origin.template isa<T>() && input.template isa<U>())
|
||||
return kind;
|
||||
return llvm::None;
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
|
||||
OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
|
||||
|
||||
using DecomposeTypeConversionCallFn =
|
||||
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
|
||||
|
||||
using ResultConversionKindFn =
|
||||
std::function<Optional<ResultConversionKind>(Type, Type)>;
|
||||
|
||||
/// Generate a wrapper for the given decompose value conversion callback.
|
||||
template <typename T, typename FnT>
|
||||
DecomposeValueConversionCallFn
|
||||
wrapDecomposeValueConversionCallback(FnT &&callback) {
|
||||
return [callback = std::forward<FnT>(callback)](
|
||||
OpBuilder &builder, Location loc, Type type, Value value,
|
||||
SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
|
||||
if (T derivedType = type.dyn_cast<T>())
|
||||
return callback(builder, loc, derivedType, value, newValues);
|
||||
return llvm::None;
|
||||
};
|
||||
}
|
||||
|
||||
/// Generate a wrapper for the given decompose type conversion callback.
|
||||
template <typename T, typename FnT>
|
||||
DecomposeTypeConversionCallFn
|
||||
wrapDecomposeTypeConversionCallback(FnT &&callback) {
|
||||
return [callback = std::forward<FnT>(callback)](
|
||||
Type type,
|
||||
SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
|
||||
T derivedType = type.dyn_cast<T>();
|
||||
if (!derivedType)
|
||||
return llvm::None;
|
||||
return callback(derivedType, results);
|
||||
};
|
||||
}
|
||||
|
||||
SmallVector<ResultConversionKindFn, 2> resultTypeConversions;
|
||||
SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
|
||||
SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
|
||||
};
|
||||
|
||||
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
|
||||
/// instance. Sample usage:
|
||||
/// class CustomConversionPattern : public
|
||||
|
@ -173,22 +68,43 @@ class BufferAssignmentOpConversionPattern
|
|||
public:
|
||||
explicit BufferAssignmentOpConversionPattern(
|
||||
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
|
||||
BufferAssignmentTypeConverter *converter = nullptr,
|
||||
PatternBenefit benefit = 1)
|
||||
TypeConverter *converter = nullptr, PatternBenefit benefit = 1)
|
||||
: OpConversionPattern<SourceOp>(context, benefit),
|
||||
bufferAssignment(bufferAssignment), converter(converter) {
|
||||
assert(converter && "The type converter has not been defined");
|
||||
}
|
||||
bufferAssignment(bufferAssignment), converter(converter) {}
|
||||
|
||||
protected:
|
||||
BufferAssignmentPlacer *bufferAssignment;
|
||||
BufferAssignmentTypeConverter *converter;
|
||||
TypeConverter *converter;
|
||||
};
|
||||
|
||||
/// Converts the signature of the function using BufferAssignmentTypeConverter.
|
||||
/// Each result type of the function is kept as a function result or appended to
|
||||
/// the function arguments list based on ResultConversionKind for the converted
|
||||
/// result type.
|
||||
/// A helper type converter class for using inside Buffer Assignment operation
|
||||
/// conversion patterns. The default constructor keeps all the types intact
|
||||
/// except for the ranked-tensor types which is converted to memref types.
|
||||
class BufferAssignmentTypeConverter : public TypeConverter {
|
||||
public:
|
||||
BufferAssignmentTypeConverter();
|
||||
|
||||
/// A helper function to check if `type` has been converted from non-memref
|
||||
/// type to memref.
|
||||
static bool isConvertedMemref(Type type, Type before);
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// Converts the signature of the function based on whether the function is
|
||||
/// allowed to return memref typed results or not using
|
||||
/// `allowMemrefFunctionResults` parameter. If this option is false, then it
|
||||
/// adds an extra function argument as an output buffer for each function result
|
||||
/// which is going to be a memref type only after type conversion. The
|
||||
/// other function result types remain unchanged. If
|
||||
/// `allowMemrefFunctionResults` is true, the types are converted in place.
|
||||
/// Any changes in function signature need to be applied
|
||||
/// to return and caller operations. `BufferAssignmentReturnOpConverter` and
|
||||
/// `BufferAssignmentCallOpConverter` are two helper function that match the
|
||||
/// return and caller operation with the new function signature. Furthermore,
|
||||
/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting
|
||||
/// tensor typed values to memref typed ones.
|
||||
template <bool allowMemrefFunctionResults>
|
||||
class BufferAssignmentFuncOpConverter
|
||||
: public BufferAssignmentOpConversionPattern<FuncOp> {
|
||||
public:
|
||||
|
@ -196,16 +112,58 @@ public:
|
|||
FuncOp>::BufferAssignmentOpConversionPattern;
|
||||
|
||||
/// Performs the actual signature rewriting step.
|
||||
LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef<Value>,
|
||||
ConversionPatternRewriter &) const;
|
||||
LogicalResult
|
||||
matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
if (!converter)
|
||||
return funcOp.emitError("The type converter has not been defined for "
|
||||
"BufferAssignmentFuncOpConverter");
|
||||
auto funcType = funcOp.getType();
|
||||
|
||||
// Convert function arguments using the provided TypeConverter.
|
||||
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
|
||||
for (auto argType : llvm::enumerate(funcType.getInputs()))
|
||||
conversion.addInputs(argType.index(),
|
||||
converter->convertType(argType.value()));
|
||||
|
||||
// If allowMemrefFunctionResults is false and a function result type is not
|
||||
// a memref but it would be a memref after type conversion, a new argument
|
||||
// should be appended to the function arguments list for this result.
|
||||
// Otherwise, it remains unchanged as a function result.
|
||||
SmallVector<Type, 2> newResultTypes;
|
||||
newResultTypes.reserve(funcOp.getNumResults());
|
||||
for (Type resType : funcType.getResults()) {
|
||||
Type convertedType = converter->convertType(resType);
|
||||
if (!allowMemrefFunctionResults &&
|
||||
BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
|
||||
resType))
|
||||
conversion.addInputs(convertedType);
|
||||
else
|
||||
newResultTypes.push_back(convertedType);
|
||||
}
|
||||
if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
|
||||
&conversion)))
|
||||
return failure();
|
||||
|
||||
// Update the signature of the function.
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
|
||||
newResultTypes));
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewrites the `ReturnOp` to conform with the changed function signature.
|
||||
/// Operands that correspond to return values and their types have been set to
|
||||
/// AppendToArgumentsList are dropped. In their place, a corresponding copy
|
||||
/// operation from the operand to the target function argument is inserted.
|
||||
/// if allowMemrefFunctionResults is false, operands that correspond to return
|
||||
/// values and have been rewritten from illegal typed results to memref
|
||||
/// arguments are dropped. In their place, a corresponding copy operation from
|
||||
/// the operand to the output function argument is inserted. Otherwise, the
|
||||
/// memref typed operands are returned.
|
||||
/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
|
||||
/// allowMemrefFunctionResults must be set/unset for both.
|
||||
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
|
||||
typename CopyOpTy>
|
||||
typename CopyOpTy, bool allowMemrefFunctionResults>
|
||||
class BufferAssignmentReturnOpConverter
|
||||
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
|
||||
public:
|
||||
|
@ -216,48 +174,44 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
Location loc = returnOp.getLoc();
|
||||
|
||||
// Split the operands depending on whether they need a copy operation or
|
||||
// they remain as operands of the return operation. If an operand is
|
||||
// decomposable and a decompose callback function has been provided by the
|
||||
// user, it will be unpacked.
|
||||
SmallVector<Value, 2> newOperands, needCopyOperands;
|
||||
OpBuilder builder(returnOp);
|
||||
for (auto operand : llvm::enumerate(operands)) {
|
||||
SmallVector<Value, 2> values;
|
||||
this->converter->tryDecomposeValue(
|
||||
builder, loc, operand.value().getType(), operand.value(), values);
|
||||
Type type = returnOp.getOperand(operand.index()).getType();
|
||||
SmallVector<Type, 2> originTypes;
|
||||
this->converter->tryDecomposeType(type, originTypes);
|
||||
for (auto value : llvm::enumerate(values)) {
|
||||
Type origin = originTypes[value.index()];
|
||||
Type converted = value.value().getType();
|
||||
auto kind = this->converter->getResultConversionKind(origin, converted);
|
||||
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
|
||||
newOperands.push_back(value.value());
|
||||
else
|
||||
// kind = BufferAssignmentTypeConverter::AppendToArgumentsList
|
||||
needCopyOperands.push_back(value.value());
|
||||
}
|
||||
// If the memref typed results can be returned as function results, the new
|
||||
// `ReturnOp` should only return the type converted operands.
|
||||
if (allowMemrefFunctionResults) {
|
||||
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Insert Copy operations instead for the operands that have been removed
|
||||
// from operand list and appended to the function arguments list.
|
||||
// Split the operands by their kinds whether they are converted memref or
|
||||
// not.
|
||||
SmallVector<Value, 2> needCopyOperands, newOperands;
|
||||
unsigned operandsSize = operands.size();
|
||||
needCopyOperands.reserve(operandsSize);
|
||||
newOperands.reserve(operandsSize);
|
||||
for (auto operand : llvm::enumerate(operands))
|
||||
if (BufferAssignmentTypeConverter::isConvertedMemref(
|
||||
operand.value().getType(),
|
||||
returnOp.getOperand(operand.index()).getType()))
|
||||
needCopyOperands.push_back(operand.value());
|
||||
else
|
||||
newOperands.push_back(operand.value());
|
||||
|
||||
Block &entryBlock = returnOp.getParentRegion()->front();
|
||||
unsigned numFuncArgs = entryBlock.getNumArguments();
|
||||
if (needCopyOperands.size() > numFuncArgs)
|
||||
return returnOp.emitError(
|
||||
"The number of operands that need Copy operations is more "
|
||||
"than the number of target function arguments.");
|
||||
|
||||
// Find the index of the first destination buffer.
|
||||
assert(needCopyOperands.size() <= numFuncArgs &&
|
||||
"The number of operands of return operation is more than the "
|
||||
"number of function arguments.");
|
||||
unsigned destArgNum = numFuncArgs - needCopyOperands.size();
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
for (Value operand : needCopyOperands) {
|
||||
rewriter.create<CopyOpTy>(loc, operand,
|
||||
// Insert a `CopyOp` for each converted memref-type operand.
|
||||
rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
|
||||
entryBlock.getArgument(destArgNum));
|
||||
++destArgNum;
|
||||
}
|
||||
|
||||
// Insert the new target Return operation.
|
||||
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
|
||||
return success();
|
||||
}
|
||||
|
@ -265,32 +219,94 @@ public:
|
|||
|
||||
/// Rewrites the `CallOp` to match its operands and results with the signature
|
||||
/// of the callee after rewriting the callee with
|
||||
/// BufferAssignmentFuncOpConverter.
|
||||
/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a
|
||||
/// buffer is allocated as an output buffer only for each memref typed result
|
||||
/// that has been rewritten. The new allocated buffer is passed through the
|
||||
/// operands list of the new `CallOp`.
|
||||
/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
|
||||
/// allowMemrefFunctionResults must be set/unset for both.
|
||||
template <bool allowMemrefFunctionResults>
|
||||
class BufferAssignmentCallOpConverter
|
||||
: public BufferAssignmentOpConversionPattern<CallOp> {
|
||||
public:
|
||||
using BufferAssignmentOpConversionPattern<
|
||||
CallOp>::BufferAssignmentOpConversionPattern;
|
||||
|
||||
/// Performs the actual rewriting step.
|
||||
LogicalResult matchAndRewrite(CallOp, ArrayRef<Value>,
|
||||
ConversionPatternRewriter &) const;
|
||||
LogicalResult
|
||||
matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
if (!converter)
|
||||
return callOp.emitError("The type converter has not been defined for "
|
||||
"BufferAssignmentCallOpConverter");
|
||||
Location loc = callOp.getLoc();
|
||||
|
||||
// If the memref typed results can be returned as function results, there is
|
||||
// no need to create output buffers. It is only required to convert the type
|
||||
// of operands and results in place for creating the new `CallOp`.
|
||||
if (allowMemrefFunctionResults) {
|
||||
SmallVector<Type, 2> resultTypes;
|
||||
resultTypes.reserve(callOp.getNumResults());
|
||||
for (Type type : callOp.getResultTypes())
|
||||
resultTypes.push_back(converter->convertType(type));
|
||||
rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
|
||||
resultTypes, operands);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value, 2> newOperands, replacingValues;
|
||||
SmallVector<Type, 2> newResultTypes;
|
||||
unsigned numResults = callOp.getNumResults();
|
||||
newOperands.reserve(numResults + operands.size());
|
||||
newOperands.append(operands.begin(), operands.end());
|
||||
newResultTypes.reserve(numResults);
|
||||
replacingValues.reserve(numResults);
|
||||
|
||||
// For each memref result of `CallOp` which has not been a memref before
|
||||
// the type conversion, a new buffer is allocated and passed to the operands
|
||||
// list of the new `CallOp`. Otherwise, it remains as a caller result.
|
||||
for (Value result : callOp.getResults()) {
|
||||
Type currType = result.getType();
|
||||
Type newType = converter->convertType(result.getType());
|
||||
if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
|
||||
result.dyn_cast<OpResult>()));
|
||||
Value alloc =
|
||||
rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
|
||||
newOperands.push_back(alloc);
|
||||
replacingValues.push_back(alloc);
|
||||
} else {
|
||||
newResultTypes.push_back(currType);
|
||||
|
||||
// No replacing is required.
|
||||
replacingValues.push_back(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Creating the new `CallOp`.
|
||||
rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
|
||||
newOperands);
|
||||
|
||||
// Replacing the results of the old `CallOp`.
|
||||
rewriter.replaceOp(callOp, replacingValues);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end namespace detail
|
||||
|
||||
/// Populates `patterns` with the conversion patterns of buffer
|
||||
/// assignment.
|
||||
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
|
||||
typename CopyOpTy>
|
||||
typename CopyOpTy, bool allowMemrefFunctionResults>
|
||||
static void populateWithBufferAssignmentOpConversionPatterns(
|
||||
MLIRContext *context, BufferAssignmentPlacer *placer,
|
||||
BufferAssignmentTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns) {
|
||||
TypeConverter *converter, OwningRewritePatternList *patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
BufferAssignmentCallOpConverter,
|
||||
BufferAssignmentFuncOpConverter,
|
||||
BufferAssignmentReturnOpConverter
|
||||
<ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy>
|
||||
detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
|
||||
detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
|
||||
detail::BufferAssignmentReturnOpConverter
|
||||
<ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
|
||||
>(context, placer, converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -100,11 +100,11 @@ public:
|
|||
/// tensors to buffers.
|
||||
static void populateConvertLinalgOnTensorsToBuffersPattern(
|
||||
MLIRContext *context, BufferAssignmentPlacer *placer,
|
||||
BufferAssignmentTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns) {
|
||||
TypeConverter *converter, OwningRewritePatternList *patterns) {
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
|
||||
converter, patterns);
|
||||
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
|
||||
/*allowMemrefFunctionResults=*/false>(context, placer, converter,
|
||||
patterns);
|
||||
patterns->insert<GenericOpConverter>(context, placer, converter);
|
||||
}
|
||||
|
||||
|
@ -141,9 +141,6 @@ struct ConvertLinalgOnTensorsToBuffers
|
|||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
|
||||
converter.setResultConversionKind<RankedTensorType, MemRefType>(
|
||||
BufferAssignmentTypeConverter::AppendToArgumentsList);
|
||||
|
||||
// Walk over all the functions to apply buffer assignment.
|
||||
getOperation().walk([&](FuncOp function) -> WalkResult {
|
||||
OwningRewritePatternList patterns;
|
||||
|
|
|
@ -713,223 +713,9 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
|
|||
});
|
||||
}
|
||||
|
||||
/// This method tries to decompose a value of a certain type using provided
|
||||
/// decompose callback functions. If it is unable to do so, the original value
|
||||
/// is returned.
|
||||
void BufferAssignmentTypeConverter::tryDecomposeValue(
|
||||
OpBuilder &builder, Location loc, Type type, Value value,
|
||||
SmallVectorImpl<Value> &results) {
|
||||
for (auto conversion : decomposeValueConversions)
|
||||
if (conversion(builder, loc, type, value, results) != llvm::None)
|
||||
return;
|
||||
results.push_back(value);
|
||||
}
|
||||
|
||||
/// This method tries to decompose a type using provided decompose callback
|
||||
/// functions. If it is unable to do so, the original type is returned.
|
||||
void BufferAssignmentTypeConverter::tryDecomposeType(
|
||||
Type type, SmallVectorImpl<Type> &types) {
|
||||
for (auto conversion : decomposeTypeConversions)
|
||||
if (conversion(type, types) != llvm::None)
|
||||
return;
|
||||
types.push_back(type);
|
||||
}
|
||||
|
||||
/// This method returns ResultConversionKind for the input type.
|
||||
BufferAssignmentTypeConverter::ResultConversionKind
|
||||
BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
|
||||
Type converted) {
|
||||
for (auto conversion : resultTypeConversions) {
|
||||
auto res = conversion(origin, converted);
|
||||
if (res != llvm::None)
|
||||
return res.getValue();
|
||||
}
|
||||
return KeepAsFunctionResult;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAssignmentFuncOpConverter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Performs the actual function signature rewriting step.
|
||||
LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
|
||||
mlir::FuncOp funcOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto funcType = funcOp.getType();
|
||||
|
||||
// Convert function arguments using the provided TypeConverter.
|
||||
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
|
||||
for (auto argType : llvm::enumerate(funcType.getInputs())) {
|
||||
SmallVector<Type, 2> decomposedTypes, convertedTypes;
|
||||
converter->tryDecomposeType(argType.value(), decomposedTypes);
|
||||
converter->convertTypes(decomposedTypes, convertedTypes);
|
||||
conversion.addInputs(argType.index(), convertedTypes);
|
||||
}
|
||||
|
||||
// Convert the result types of the function.
|
||||
SmallVector<Type, 2> newResultTypes;
|
||||
newResultTypes.reserve(funcOp.getNumResults());
|
||||
for (Type resultType : funcType.getResults()) {
|
||||
SmallVector<Type, 2> originTypes;
|
||||
converter->tryDecomposeType(resultType, originTypes);
|
||||
for (auto origin : originTypes) {
|
||||
Type converted = converter->convertType(origin);
|
||||
auto kind = converter->getResultConversionKind(origin, converted);
|
||||
if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
|
||||
conversion.addInputs(converted);
|
||||
else
|
||||
// kind = BufferAssignmentTypeConverter::KeepAsFunctionResult
|
||||
newResultTypes.push_back(converted);
|
||||
}
|
||||
}
|
||||
|
||||
if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
|
||||
&conversion)))
|
||||
return failure();
|
||||
|
||||
// Update the signature of the function.
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
|
||||
newResultTypes));
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferAssignmentCallOpConverter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Performs the actual rewriting step.
|
||||
LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
|
||||
CallOp callOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// This class represents a mapping from a result to a list of values and some
|
||||
// results that have not yet constructed. Instead, the indices of these
|
||||
// results in the operation that will be constructed are known. They will be
|
||||
// replaced with the actual values when they are available. The order of
|
||||
// adding to this mapping is important.
|
||||
class ResultMapping {
|
||||
public:
|
||||
ResultMapping() { order = 0; };
|
||||
|
||||
/// Add an available value to the mapping.
|
||||
void addMapping(Value value) {
|
||||
toValuesMapping.push_back({order++, value});
|
||||
}
|
||||
|
||||
/// Add the index of unavailble result value to the mapping.
|
||||
void addMapping(unsigned index) {
|
||||
toIndicesMapping.push_back({order++, index});
|
||||
}
|
||||
|
||||
/// This method returns the mapping values list. The unknown result values
|
||||
/// that only their indicies are available are replaced with their values.
|
||||
void getMappingValues(ValueRange valuesToReplaceIndices,
|
||||
SmallVectorImpl<Value> &values) {
|
||||
// Append available values to the list.
|
||||
SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
|
||||
toValuesMapping.end());
|
||||
// Replace the indices with the actual values.
|
||||
llvm::for_each(
|
||||
toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
|
||||
assert(entry.second < valuesToReplaceIndices.size() &&
|
||||
"The value index is out of range.");
|
||||
res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
|
||||
});
|
||||
// Sort the values based on their adding orders.
|
||||
llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
|
||||
const std::pair<unsigned, Value> &v2) {
|
||||
return v1.first < v2.first;
|
||||
});
|
||||
// Fill the values.
|
||||
llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
|
||||
values.push_back(entry.second);
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
/// Keeping the inserting order of mapping values.
|
||||
int order;
|
||||
|
||||
/// Containing the mapping values with their inserting orders.
|
||||
SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
|
||||
|
||||
/// Containing the indices of result values with their inserting orders.
|
||||
SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
|
||||
};
|
||||
|
||||
Location loc = callOp.getLoc();
|
||||
OpBuilder builder(callOp);
|
||||
SmallVector<Value, 2> newOperands;
|
||||
|
||||
// Create the operands list of the new `CallOp`. It unpacks the decomposable
|
||||
// values if a decompose callback function has been provided by the user.
|
||||
for (auto operand : operands) {
|
||||
SmallVector<Value, 2> values;
|
||||
this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
|
||||
values);
|
||||
newOperands.append(values.begin(), values.end());
|
||||
}
|
||||
|
||||
// Create the new result types for the new `CallOp` and a mapping from the old
|
||||
// result to new value(s).
|
||||
SmallVector<Type, 2> newResultTypes;
|
||||
SmallVector<ResultMapping, 4> mappings;
|
||||
mappings.resize(callOp.getNumResults());
|
||||
for (auto result : llvm::enumerate(callOp.getResults())) {
|
||||
SmallVector<Type, 2> originTypes;
|
||||
converter->tryDecomposeType(result.value().getType(), originTypes);
|
||||
auto &resultMapping = mappings[result.index()];
|
||||
for (Type origin : originTypes) {
|
||||
Type converted = converter->convertType(origin);
|
||||
auto kind = converter->getResultConversionKind(origin, converted);
|
||||
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
|
||||
newResultTypes.push_back(converted);
|
||||
// The result value is not yet available. Its index is kept and it is
|
||||
// replaced with the actual value of the new `CallOp` later.
|
||||
resultMapping.addMapping(newResultTypes.size() - 1);
|
||||
} else {
|
||||
// kind = BufferAssignmentTypeConverter::AppendToArgumentsList
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.restoreInsertionPoint(
|
||||
bufferAssignment->computeAllocPosition(result.value()));
|
||||
MemRefType memref = converted.dyn_cast<MemRefType>();
|
||||
if (!memref)
|
||||
return callOp.emitError("Cannot allocate for a non-Memref type");
|
||||
Value alloc = rewriter.create<AllocOp>(loc, memref);
|
||||
newOperands.push_back(alloc);
|
||||
resultMapping.addMapping(alloc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
|
||||
newResultTypes, newOperands);
|
||||
|
||||
// Build a replacing value for each result to replace its uses. If a result
|
||||
// has multiple mapping values, it needs to be packed to a single value.
|
||||
OpBuilder nextBuilder(callOp.getOperation()->getNextNode());
|
||||
SmallVector<Value, 2> replacedValues;
|
||||
replacedValues.reserve(callOp.getNumResults());
|
||||
for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
|
||||
SmallVector<Value, 2> valuesToPack;
|
||||
mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
|
||||
if (valuesToPack.empty()) {
|
||||
// No replacement is required.
|
||||
replacedValues.push_back(nullptr);
|
||||
} else if (valuesToPack.size() == 1) {
|
||||
replacedValues.push_back(valuesToPack.front());
|
||||
} else {
|
||||
// Values need to be packed using callback function. The same callback
|
||||
// that is used for materializeArgumentConversion is used for packing.
|
||||
Value packed = converter->materializeArgumentConversion(
|
||||
nextBuilder, loc, callOp.getType(i), valuesToPack);
|
||||
replacedValues.push_back(packed);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOp(callOp, replacedValues);
|
||||
return success();
|
||||
/// Checks if `type` has been converted from non-memref type to memref.
|
||||
bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
|
||||
return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -111,73 +111,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
|
|||
// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
|
||||
// CHECK: return %[[Y]]#0
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
|
||||
// signature of the new signature of the callee function when there are tuple typed
|
||||
// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
|
||||
// arguments. The tuple typed values should be decomposed and composed using
|
||||
// get_tuple_element and make_tuple operations of test dialect. Tensor types are
|
||||
// converted to Memref. Memref typed function results remain as function results.
|
||||
|
||||
// CHECK-LABEL: func @callee
|
||||
func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
|
||||
return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
||||
}
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
|
||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
|
||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
|
||||
|
||||
// CHECK-LABEL: func @caller
|
||||
func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
|
||||
%x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
||||
%y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
||||
return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
||||
}
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
|
||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
|
||||
// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
|
||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
|
||||
// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
|
||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
|
||||
// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Testing BufferAssginmnetFuncOpConverter and
|
||||
// BufferAssginmentReturnOpConverter to see if the return operation matches with
|
||||
// the new function signature when there are tuple typed args and results.
|
||||
// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
|
||||
// typed values should be decomposed and composed using get_tuple_element and
|
||||
// make_tuple operations of test dialect. Tensor types are converted to Memref.
|
||||
// Memref typed function results remain as function results.
|
||||
|
||||
// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
|
||||
func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
|
||||
return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
|
||||
}
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>
|
||||
// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32)
|
||||
// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
|
||||
// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
|
||||
// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
|
||||
|
|
|
@ -285,93 +285,8 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
|
|||
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
|
||||
// CHECK: return
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_with_unranked_arg
|
||||
func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
|
||||
return
|
||||
}
|
||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
|
||||
// signature of the new signature of the callee function when there are tuple typed
|
||||
// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
|
||||
// arguments. The tuple typed values should be decomposed and composed using
|
||||
// get_tuple_element and make_tuple operations of test dialect. Tensor types are
|
||||
// converted to Memref. Memref typed function results are appended to the function
|
||||
// arguments list.
|
||||
|
||||
// CHECK-LABEL: func @callee
|
||||
func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
|
||||
return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
||||
}
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
|
||||
// CHECK-SAME: i1
|
||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
|
||||
// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
|
||||
// CHECK-NEXT: return %[[SECOND_ELEM]]
|
||||
|
||||
|
||||
// CHECK-LABEL: func @caller
|
||||
func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
|
||||
%x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
||||
%y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
||||
return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
||||
}
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
|
||||
// CHECK-SAME: i1
|
||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
|
||||
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
|
||||
// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
|
||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
|
||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
|
||||
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
|
||||
// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
|
||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
|
||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
|
||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
||||
// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
|
||||
// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
|
||||
// CHECK-NEXT: return %[[SECOND_ELEM]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: Testing BufferAssginmnetFuncOpConverter and
|
||||
// BufferAssginmentReturnOpConverter to see if the return operation matches with
|
||||
// the new function signature when there are tuple typed args and results.
|
||||
// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
|
||||
// typed values should be decomposed and composed using get_tuple_element and
|
||||
// make_tuple operations of test dialect. Tensor types are converted to Memref.
|
||||
// Memref typed function results are appended to the function arguments list.
|
||||
|
||||
// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
|
||||
func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
|
||||
return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
|
||||
}
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32>
|
||||
// CHECK-SAME: (i1, i1, f32)
|
||||
// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
|
||||
// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
|
||||
// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
|
||||
// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
|
||||
// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]])
|
||||
// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
|
||||
// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
|
||||
|
|
|
@ -1669,7 +1669,7 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
|
|||
let results = (outs AnyType:$result);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
static LogicalResult inferReturnTypes(MLIRContext *,
|
||||
static LogicalResult inferReturnTypes(MLIRContext *,
|
||||
Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
|
@ -1679,31 +1679,4 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test BufferPlacement
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
|
||||
let description = [{
|
||||
Test op that returns a specified element of the tuple.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TupleOf<[AnyType]>,
|
||||
I32Attr:$index
|
||||
);
|
||||
let results = (outs AnyType);
|
||||
}
|
||||
|
||||
def MakeTupleOp: TEST_Op<"make_tuple"> {
|
||||
let description = [{
|
||||
Test op that creates a tuple value from a list of values.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<AnyType>:$inputs
|
||||
);
|
||||
let results = (outs TupleOf<[AnyType]>);
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
|
|
@ -11,8 +11,6 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "TestDialect.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
@ -111,16 +109,14 @@ struct TestBufferPlacementPreparationPass
|
|||
|
||||
void populateTensorLinalgToBufferLinalgConversionPattern(
|
||||
MLIRContext *context, BufferAssignmentPlacer *placer,
|
||||
BufferAssignmentTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns) {
|
||||
TypeConverter *converter, OwningRewritePatternList *patterns) {
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
|
||||
converter, patterns);
|
||||
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
|
||||
allowMemrefFunctionResults>(context, placer, converter, patterns);
|
||||
patterns->insert<GenericOpConverter>(context, placer, converter);
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TestDialect>();
|
||||
registry.insert<linalg::LinalgDialect>();
|
||||
}
|
||||
|
||||
|
@ -131,8 +127,6 @@ struct TestBufferPlacementPreparationPass
|
|||
|
||||
// Mark all Standard operations legal.
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalOp<MakeTupleOp>();
|
||||
target.addLegalOp<GetTupleElementOp>();
|
||||
|
||||
// Mark all Linalg operations illegal as long as they work on tensors.
|
||||
auto isLegalOperation = [&](Operation *op) {
|
||||
|
@ -155,42 +149,6 @@ struct TestBufferPlacementPreparationPass
|
|||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
|
||||
auto kind = allowMemrefFunctionResults
|
||||
? BufferAssignmentTypeConverter::KeepAsFunctionResult
|
||||
: BufferAssignmentTypeConverter::AppendToArgumentsList;
|
||||
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
|
||||
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
|
||||
kind);
|
||||
|
||||
converter.addDecomposeTypeConversion(
|
||||
[](TupleType tupleType, SmallVectorImpl<Type> &types) {
|
||||
tupleType.getFlattenedTypes(types);
|
||||
return success();
|
||||
});
|
||||
|
||||
converter.addArgumentMaterialization(
|
||||
[](OpBuilder &builder, TupleType resultType, ValueRange inputs,
|
||||
Location loc) -> Optional<Value> {
|
||||
if (inputs.size() == 1)
|
||||
return llvm::None;
|
||||
TypeRange TypeRange = inputs.getTypes();
|
||||
SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
|
||||
TupleType tuple = TupleType::get(types, builder.getContext());
|
||||
mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs);
|
||||
return value;
|
||||
});
|
||||
|
||||
converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
|
||||
TupleType resultType, Value value,
|
||||
SmallVectorImpl<Value> &values) {
|
||||
for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
|
||||
Value res = builder.create<GetTupleElementOp>(
|
||||
loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
|
||||
values.push_back(res);
|
||||
}
|
||||
return success();
|
||||
});
|
||||
|
||||
// Walk over all the functions to apply buffer assignment.
|
||||
this->getOperation().walk([&](FuncOp function) -> WalkResult {
|
||||
OwningRewritePatternList patterns;
|
||||
|
|
Loading…
Reference in New Issue