forked from OSchip/llvm-project
Revert "[mlir] Extend BufferAssignmentTypeConverter with result conversion callbacks"
This reverts commit 94f5d24877
because
of failing the following tests:
MLIR :: Dialect/Linalg/tensors-to-buffers.mlir
MLIR :: Transforms/buffer-placement-preparation-allowed-memref-results.mlir
MLIR :: Transforms/buffer-placement-preparation.mlir
This commit is contained in:
parent
6d36b22b21
commit
1b88bbf5eb
|
@ -52,111 +52,6 @@ private:
|
||||||
Operation *operation;
|
Operation *operation;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A helper type converter class for using inside Buffer Assignment operation
|
|
||||||
/// conversion patterns. The default constructor keeps all the types intact
|
|
||||||
/// except for the ranked-tensor types which is converted to memref types.
|
|
||||||
class BufferAssignmentTypeConverter : public TypeConverter {
|
|
||||||
public:
|
|
||||||
/// This enum is for showing how buffer placement operation converters should
|
|
||||||
/// conduct with certain result type after type conversion. This value can be
|
|
||||||
/// set/get for each specific type using setResultConversionKind or
|
|
||||||
/// getResultConversionKind.
|
|
||||||
enum ResultConversionKind { AppendToArgumentsList, KeepAsFunctionResult };
|
|
||||||
|
|
||||||
BufferAssignmentTypeConverter();
|
|
||||||
|
|
||||||
/// This method tries to decompose a value of a certain type using provided
|
|
||||||
/// decompose callback functions. If it is unable to do so, the original value
|
|
||||||
/// is returned.
|
|
||||||
void tryDecomposeValue(OpBuilder &, Location, Type, Value,
|
|
||||||
SmallVectorImpl<Value> &);
|
|
||||||
|
|
||||||
/// This method tries to decompose a type using provided decompose callback
|
|
||||||
/// functions. If it is unable to do so, the original type is returned.
|
|
||||||
void tryDecomposeType(Type, SmallVectorImpl<Type> &);
|
|
||||||
|
|
||||||
/// This method registers a callback function that will be called to decompose
|
|
||||||
/// a value of a certain type into several values.
|
|
||||||
template <typename FnT,
|
|
||||||
typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
|
|
||||||
void addDecomposeValueConversion(FnT &&callback) {
|
|
||||||
decomposeValueConversions.emplace_back(
|
|
||||||
wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This method registers a callback function that will be called to decompose
|
|
||||||
/// a type into several types.
|
|
||||||
template <typename FnT,
|
|
||||||
typename T = typename llvm::function_traits<FnT>::template arg_t<0>>
|
|
||||||
void addDecomposeTypeConversion(FnT &&callback) {
|
|
||||||
auto wrapper =
|
|
||||||
wrapDecomposeTypeConversionCallback<T>(std::forward<FnT>(callback));
|
|
||||||
decomposeTypeConversions.emplace_back(wrapper);
|
|
||||||
addConversion(std::forward<FnT>(callback));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This method returns ResultConversionKind for the mapping from `origin`
|
|
||||||
/// type to `input` type.
|
|
||||||
ResultConversionKind getResultConversionKind(Type origin, Type input);
|
|
||||||
|
|
||||||
/// This method registers ResultConversionKind for the mapping from type 'T'
|
|
||||||
/// to type 'U'.
|
|
||||||
template <typename T, typename U>
|
|
||||||
void setResultConversionKind(ResultConversionKind kind) {
|
|
||||||
assert((kind != AppendToArgumentsList ||
|
|
||||||
llvm::is_one_of<U, MemRefType, UnrankedMemRefType>::value) &&
|
|
||||||
"Only the memref typed values can be set to be appended to the "
|
|
||||||
"function argument list at the moment");
|
|
||||||
resultTypeConversions.emplace_back(
|
|
||||||
[&](Type origin, Type input) -> Optional<ResultConversionKind> {
|
|
||||||
if (origin.template isa<T>() && input.template isa<U>())
|
|
||||||
return kind;
|
|
||||||
return llvm::None;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
using DecomposeValueConversionCallFn = std::function<Optional<LogicalResult>(
|
|
||||||
OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
|
|
||||||
|
|
||||||
using DecomposeTypeConversionCallFn =
|
|
||||||
std::function<Optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
|
|
||||||
|
|
||||||
using ResultConversionKindFn =
|
|
||||||
std::function<Optional<ResultConversionKind>(Type, Type)>;
|
|
||||||
|
|
||||||
/// Generate a wrapper for the given decompose value conversion callback.
|
|
||||||
template <typename T, typename FnT>
|
|
||||||
DecomposeValueConversionCallFn
|
|
||||||
wrapDecomposeValueConversionCallback(FnT &&callback) {
|
|
||||||
return [callback = std::forward<FnT>(callback)](
|
|
||||||
OpBuilder &builder, Location loc, Type type, Value value,
|
|
||||||
SmallVectorImpl<Value> &newValues) -> Optional<LogicalResult> {
|
|
||||||
if (T derivedType = type.dyn_cast<T>())
|
|
||||||
return callback(builder, loc, derivedType, value, newValues);
|
|
||||||
return llvm::None;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generate a wrapper for the given decompose type conversion callback.
|
|
||||||
template <typename T, typename FnT>
|
|
||||||
DecomposeTypeConversionCallFn
|
|
||||||
wrapDecomposeTypeConversionCallback(FnT &&callback) {
|
|
||||||
return [callback = std::forward<FnT>(callback)](
|
|
||||||
Type type,
|
|
||||||
SmallVectorImpl<Type> &results) -> Optional<LogicalResult> {
|
|
||||||
T derivedType = type.dyn_cast<T>();
|
|
||||||
if (!derivedType)
|
|
||||||
return llvm::None;
|
|
||||||
return callback(derivedType, results);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<ResultConversionKindFn, 2> resultTypeConversions;
|
|
||||||
SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
|
|
||||||
SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
|
/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer
|
||||||
/// instance. Sample usage:
|
/// instance. Sample usage:
|
||||||
/// class CustomConversionPattern : public
|
/// class CustomConversionPattern : public
|
||||||
|
@ -173,22 +68,43 @@ class BufferAssignmentOpConversionPattern
|
||||||
public:
|
public:
|
||||||
explicit BufferAssignmentOpConversionPattern(
|
explicit BufferAssignmentOpConversionPattern(
|
||||||
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
|
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr,
|
||||||
BufferAssignmentTypeConverter *converter = nullptr,
|
TypeConverter *converter = nullptr, PatternBenefit benefit = 1)
|
||||||
PatternBenefit benefit = 1)
|
|
||||||
: OpConversionPattern<SourceOp>(context, benefit),
|
: OpConversionPattern<SourceOp>(context, benefit),
|
||||||
bufferAssignment(bufferAssignment), converter(converter) {
|
bufferAssignment(bufferAssignment), converter(converter) {}
|
||||||
assert(converter && "The type converter has not been defined");
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
BufferAssignmentPlacer *bufferAssignment;
|
BufferAssignmentPlacer *bufferAssignment;
|
||||||
BufferAssignmentTypeConverter *converter;
|
TypeConverter *converter;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Converts the signature of the function using BufferAssignmentTypeConverter.
|
/// A helper type converter class for using inside Buffer Assignment operation
|
||||||
/// Each result type of the function is kept as a function result or appended to
|
/// conversion patterns. The default constructor keeps all the types intact
|
||||||
/// the function arguments list based on ResultConversionKind for the converted
|
/// except for the ranked-tensor types which is converted to memref types.
|
||||||
/// result type.
|
class BufferAssignmentTypeConverter : public TypeConverter {
|
||||||
|
public:
|
||||||
|
BufferAssignmentTypeConverter();
|
||||||
|
|
||||||
|
/// A helper function to check if `type` has been converted from non-memref
|
||||||
|
/// type to memref.
|
||||||
|
static bool isConvertedMemref(Type type, Type before);
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
/// Converts the signature of the function based on whether the function is
|
||||||
|
/// allowed to return memref typed results or not using
|
||||||
|
/// `allowMemrefFunctionResults` parameter. If this option is false, then it
|
||||||
|
/// adds an extra function argument as an output buffer for each function result
|
||||||
|
/// which is going to be a memref type only after type conversion. The
|
||||||
|
/// other function result types remain unchanged. If
|
||||||
|
/// `allowMemrefFunctionResults` is true, the types are converted in place.
|
||||||
|
/// Any changes in function signature need to be applied
|
||||||
|
/// to return and caller operations. `BufferAssignmentReturnOpConverter` and
|
||||||
|
/// `BufferAssignmentCallOpConverter` are two helper function that match the
|
||||||
|
/// return and caller operation with the new function signature. Furthermore,
|
||||||
|
/// `BufferAssignmentTypeConverter` is a helper `TypeConverter` for converting
|
||||||
|
/// tensor typed values to memref typed ones.
|
||||||
|
template <bool allowMemrefFunctionResults>
|
||||||
class BufferAssignmentFuncOpConverter
|
class BufferAssignmentFuncOpConverter
|
||||||
: public BufferAssignmentOpConversionPattern<FuncOp> {
|
: public BufferAssignmentOpConversionPattern<FuncOp> {
|
||||||
public:
|
public:
|
||||||
|
@ -196,16 +112,58 @@ public:
|
||||||
FuncOp>::BufferAssignmentOpConversionPattern;
|
FuncOp>::BufferAssignmentOpConversionPattern;
|
||||||
|
|
||||||
/// Performs the actual signature rewriting step.
|
/// Performs the actual signature rewriting step.
|
||||||
LogicalResult matchAndRewrite(mlir::FuncOp, ArrayRef<Value>,
|
LogicalResult
|
||||||
ConversionPatternRewriter &) const;
|
matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
if (!converter)
|
||||||
|
return funcOp.emitError("The type converter has not been defined for "
|
||||||
|
"BufferAssignmentFuncOpConverter");
|
||||||
|
auto funcType = funcOp.getType();
|
||||||
|
|
||||||
|
// Convert function arguments using the provided TypeConverter.
|
||||||
|
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
|
||||||
|
for (auto argType : llvm::enumerate(funcType.getInputs()))
|
||||||
|
conversion.addInputs(argType.index(),
|
||||||
|
converter->convertType(argType.value()));
|
||||||
|
|
||||||
|
// If allowMemrefFunctionResults is false and a function result type is not
|
||||||
|
// a memref but it would be a memref after type conversion, a new argument
|
||||||
|
// should be appended to the function arguments list for this result.
|
||||||
|
// Otherwise, it remains unchanged as a function result.
|
||||||
|
SmallVector<Type, 2> newResultTypes;
|
||||||
|
newResultTypes.reserve(funcOp.getNumResults());
|
||||||
|
for (Type resType : funcType.getResults()) {
|
||||||
|
Type convertedType = converter->convertType(resType);
|
||||||
|
if (!allowMemrefFunctionResults &&
|
||||||
|
BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
|
||||||
|
resType))
|
||||||
|
conversion.addInputs(convertedType);
|
||||||
|
else
|
||||||
|
newResultTypes.push_back(convertedType);
|
||||||
|
}
|
||||||
|
if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
|
||||||
|
&conversion)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Update the signature of the function.
|
||||||
|
rewriter.updateRootInPlace(funcOp, [&] {
|
||||||
|
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
|
||||||
|
newResultTypes));
|
||||||
|
});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Rewrites the `ReturnOp` to conform with the changed function signature.
|
/// Rewrites the `ReturnOp` to conform with the changed function signature.
|
||||||
/// Operands that correspond to return values and their types have been set to
|
/// if allowMemrefFunctionResults is false, operands that correspond to return
|
||||||
/// AppendToArgumentsList are dropped. In their place, a corresponding copy
|
/// values and have been rewritten from illegal typed results to memref
|
||||||
/// operation from the operand to the target function argument is inserted.
|
/// arguments are dropped. In their place, a corresponding copy operation from
|
||||||
|
/// the operand to the output function argument is inserted. Otherwise, the
|
||||||
|
/// memref typed operands are returned.
|
||||||
|
/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
|
||||||
|
/// allowMemrefFunctionResults must be set/unset for both.
|
||||||
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
|
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
|
||||||
typename CopyOpTy>
|
typename CopyOpTy, bool allowMemrefFunctionResults>
|
||||||
class BufferAssignmentReturnOpConverter
|
class BufferAssignmentReturnOpConverter
|
||||||
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
|
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
|
||||||
public:
|
public:
|
||||||
|
@ -216,48 +174,44 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
|
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
Location loc = returnOp.getLoc();
|
// If the memref typed results can be returned as function results, the new
|
||||||
|
// `ReturnOp` should only return the type converted operands.
|
||||||
|
if (allowMemrefFunctionResults) {
|
||||||
|
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// Split the operands depending on whether they need a copy operation or
|
// Split the operands by their kinds whether they are converted memref or
|
||||||
// they remain as operands of the return operation. If an operand is
|
// not.
|
||||||
// decomposable and a decompose callback function has been provided by the
|
SmallVector<Value, 2> needCopyOperands, newOperands;
|
||||||
// user, it will be unpacked.
|
unsigned operandsSize = operands.size();
|
||||||
SmallVector<Value, 2> newOperands, needCopyOperands;
|
needCopyOperands.reserve(operandsSize);
|
||||||
OpBuilder builder(returnOp);
|
newOperands.reserve(operandsSize);
|
||||||
for (auto operand : llvm::enumerate(operands)) {
|
for (auto operand : llvm::enumerate(operands))
|
||||||
SmallVector<Value, 2> values;
|
if (BufferAssignmentTypeConverter::isConvertedMemref(
|
||||||
this->converter->tryDecomposeValue(
|
operand.value().getType(),
|
||||||
builder, loc, operand.value().getType(), operand.value(), values);
|
returnOp.getOperand(operand.index()).getType()))
|
||||||
Type type = returnOp.getOperand(operand.index()).getType();
|
needCopyOperands.push_back(operand.value());
|
||||||
SmallVector<Type, 2> originTypes;
|
|
||||||
this->converter->tryDecomposeType(type, originTypes);
|
|
||||||
for (auto value : llvm::enumerate(values)) {
|
|
||||||
Type origin = originTypes[value.index()];
|
|
||||||
Type converted = value.value().getType();
|
|
||||||
auto kind = this->converter->getResultConversionKind(origin, converted);
|
|
||||||
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult)
|
|
||||||
newOperands.push_back(value.value());
|
|
||||||
else
|
else
|
||||||
// kind = BufferAssignmentTypeConverter::AppendToArgumentsList
|
newOperands.push_back(operand.value());
|
||||||
needCopyOperands.push_back(value.value());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert Copy operations instead for the operands that have been removed
|
|
||||||
// from operand list and appended to the function arguments list.
|
|
||||||
Block &entryBlock = returnOp.getParentRegion()->front();
|
Block &entryBlock = returnOp.getParentRegion()->front();
|
||||||
unsigned numFuncArgs = entryBlock.getNumArguments();
|
unsigned numFuncArgs = entryBlock.getNumArguments();
|
||||||
if (needCopyOperands.size() > numFuncArgs)
|
|
||||||
return returnOp.emitError(
|
// Find the index of the first destination buffer.
|
||||||
"The number of operands that need Copy operations is more "
|
assert(needCopyOperands.size() <= numFuncArgs &&
|
||||||
"than the number of target function arguments.");
|
"The number of operands of return operation is more than the "
|
||||||
|
"number of function arguments.");
|
||||||
unsigned destArgNum = numFuncArgs - needCopyOperands.size();
|
unsigned destArgNum = numFuncArgs - needCopyOperands.size();
|
||||||
rewriter.setInsertionPoint(returnOp);
|
rewriter.setInsertionPoint(returnOp);
|
||||||
for (Value operand : needCopyOperands) {
|
for (Value operand : needCopyOperands) {
|
||||||
rewriter.create<CopyOpTy>(loc, operand,
|
// Insert a `CopyOp` for each converted memref-type operand.
|
||||||
|
rewriter.create<CopyOpTy>(returnOp.getLoc(), operand,
|
||||||
entryBlock.getArgument(destArgNum));
|
entryBlock.getArgument(destArgNum));
|
||||||
++destArgNum;
|
++destArgNum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Insert the new target Return operation.
|
||||||
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
|
rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, newOperands);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -265,32 +219,94 @@ public:
|
||||||
|
|
||||||
/// Rewrites the `CallOp` to match its operands and results with the signature
|
/// Rewrites the `CallOp` to match its operands and results with the signature
|
||||||
/// of the callee after rewriting the callee with
|
/// of the callee after rewriting the callee with
|
||||||
/// BufferAssignmentFuncOpConverter.
|
/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, a
|
||||||
|
/// buffer is allocated as an output buffer only for each memref typed result
|
||||||
|
/// that has been rewritten. The new allocated buffer is passed through the
|
||||||
|
/// operands list of the new `CallOp`.
|
||||||
|
/// Note: If this pattern rewriter is used with BufferAssignmentFuncOpConverter,
|
||||||
|
/// allowMemrefFunctionResults must be set/unset for both.
|
||||||
|
template <bool allowMemrefFunctionResults>
|
||||||
class BufferAssignmentCallOpConverter
|
class BufferAssignmentCallOpConverter
|
||||||
: public BufferAssignmentOpConversionPattern<CallOp> {
|
: public BufferAssignmentOpConversionPattern<CallOp> {
|
||||||
public:
|
public:
|
||||||
using BufferAssignmentOpConversionPattern<
|
using BufferAssignmentOpConversionPattern<
|
||||||
CallOp>::BufferAssignmentOpConversionPattern;
|
CallOp>::BufferAssignmentOpConversionPattern;
|
||||||
|
|
||||||
/// Performs the actual rewriting step.
|
LogicalResult
|
||||||
LogicalResult matchAndRewrite(CallOp, ArrayRef<Value>,
|
matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &) const;
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
if (!converter)
|
||||||
|
return callOp.emitError("The type converter has not been defined for "
|
||||||
|
"BufferAssignmentCallOpConverter");
|
||||||
|
Location loc = callOp.getLoc();
|
||||||
|
|
||||||
|
// If the memref typed results can be returned as function results, there is
|
||||||
|
// no need to create output buffers. It is only required to convert the type
|
||||||
|
// of operands and results in place for creating the new `CallOp`.
|
||||||
|
if (allowMemrefFunctionResults) {
|
||||||
|
SmallVector<Type, 2> resultTypes;
|
||||||
|
resultTypes.reserve(callOp.getNumResults());
|
||||||
|
for (Type type : callOp.getResultTypes())
|
||||||
|
resultTypes.push_back(converter->convertType(type));
|
||||||
|
rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
|
||||||
|
resultTypes, operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 2> newOperands, replacingValues;
|
||||||
|
SmallVector<Type, 2> newResultTypes;
|
||||||
|
unsigned numResults = callOp.getNumResults();
|
||||||
|
newOperands.reserve(numResults + operands.size());
|
||||||
|
newOperands.append(operands.begin(), operands.end());
|
||||||
|
newResultTypes.reserve(numResults);
|
||||||
|
replacingValues.reserve(numResults);
|
||||||
|
|
||||||
|
// For each memref result of `CallOp` which has not been a memref before
|
||||||
|
// the type conversion, a new buffer is allocated and passed to the operands
|
||||||
|
// list of the new `CallOp`. Otherwise, it remains as a caller result.
|
||||||
|
for (Value result : callOp.getResults()) {
|
||||||
|
Type currType = result.getType();
|
||||||
|
Type newType = converter->convertType(result.getType());
|
||||||
|
if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
|
||||||
|
result.dyn_cast<OpResult>()));
|
||||||
|
Value alloc =
|
||||||
|
rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
|
||||||
|
newOperands.push_back(alloc);
|
||||||
|
replacingValues.push_back(alloc);
|
||||||
|
} else {
|
||||||
|
newResultTypes.push_back(currType);
|
||||||
|
|
||||||
|
// No replacing is required.
|
||||||
|
replacingValues.push_back(nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creating the new `CallOp`.
|
||||||
|
rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
|
||||||
|
newOperands);
|
||||||
|
|
||||||
|
// Replacing the results of the old `CallOp`.
|
||||||
|
rewriter.replaceOp(callOp, replacingValues);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
} // end namespace detail
|
||||||
|
|
||||||
/// Populates `patterns` with the conversion patterns of buffer
|
/// Populates `patterns` with the conversion patterns of buffer
|
||||||
/// assignment.
|
/// assignment.
|
||||||
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
|
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
|
||||||
typename CopyOpTy>
|
typename CopyOpTy, bool allowMemrefFunctionResults>
|
||||||
static void populateWithBufferAssignmentOpConversionPatterns(
|
static void populateWithBufferAssignmentOpConversionPatterns(
|
||||||
MLIRContext *context, BufferAssignmentPlacer *placer,
|
MLIRContext *context, BufferAssignmentPlacer *placer,
|
||||||
BufferAssignmentTypeConverter *converter,
|
TypeConverter *converter, OwningRewritePatternList *patterns) {
|
||||||
OwningRewritePatternList *patterns) {
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
BufferAssignmentCallOpConverter,
|
detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
|
||||||
BufferAssignmentFuncOpConverter,
|
detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
|
||||||
BufferAssignmentReturnOpConverter
|
detail::BufferAssignmentReturnOpConverter
|
||||||
<ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy>
|
<ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
|
||||||
>(context, placer, converter);
|
>(context, placer, converter);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,11 +100,11 @@ public:
|
||||||
/// tensors to buffers.
|
/// tensors to buffers.
|
||||||
static void populateConvertLinalgOnTensorsToBuffersPattern(
|
static void populateConvertLinalgOnTensorsToBuffersPattern(
|
||||||
MLIRContext *context, BufferAssignmentPlacer *placer,
|
MLIRContext *context, BufferAssignmentPlacer *placer,
|
||||||
BufferAssignmentTypeConverter *converter,
|
TypeConverter *converter, OwningRewritePatternList *patterns) {
|
||||||
OwningRewritePatternList *patterns) {
|
|
||||||
populateWithBufferAssignmentOpConversionPatterns<
|
populateWithBufferAssignmentOpConversionPatterns<
|
||||||
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
|
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
|
||||||
converter, patterns);
|
/*allowMemrefFunctionResults=*/false>(context, placer, converter,
|
||||||
|
patterns);
|
||||||
patterns->insert<GenericOpConverter>(context, placer, converter);
|
patterns->insert<GenericOpConverter>(context, placer, converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,9 +141,6 @@ struct ConvertLinalgOnTensorsToBuffers
|
||||||
converter.isLegal(&funcOp.getBody());
|
converter.isLegal(&funcOp.getBody());
|
||||||
});
|
});
|
||||||
|
|
||||||
converter.setResultConversionKind<RankedTensorType, MemRefType>(
|
|
||||||
BufferAssignmentTypeConverter::AppendToArgumentsList);
|
|
||||||
|
|
||||||
// Walk over all the functions to apply buffer assignment.
|
// Walk over all the functions to apply buffer assignment.
|
||||||
getOperation().walk([&](FuncOp function) -> WalkResult {
|
getOperation().walk([&](FuncOp function) -> WalkResult {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
|
|
@ -713,223 +713,9 @@ BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This method tries to decompose a value of a certain type using provided
|
/// Checks if `type` has been converted from non-memref type to memref.
|
||||||
/// decompose callback functions. If it is unable to do so, the original value
|
bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
|
||||||
/// is returned.
|
return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
|
||||||
void BufferAssignmentTypeConverter::tryDecomposeValue(
|
|
||||||
OpBuilder &builder, Location loc, Type type, Value value,
|
|
||||||
SmallVectorImpl<Value> &results) {
|
|
||||||
for (auto conversion : decomposeValueConversions)
|
|
||||||
if (conversion(builder, loc, type, value, results) != llvm::None)
|
|
||||||
return;
|
|
||||||
results.push_back(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This method tries to decompose a type using provided decompose callback
|
|
||||||
/// functions. If it is unable to do so, the original type is returned.
|
|
||||||
void BufferAssignmentTypeConverter::tryDecomposeType(
|
|
||||||
Type type, SmallVectorImpl<Type> &types) {
|
|
||||||
for (auto conversion : decomposeTypeConversions)
|
|
||||||
if (conversion(type, types) != llvm::None)
|
|
||||||
return;
|
|
||||||
types.push_back(type);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This method returns ResultConversionKind for the input type.
|
|
||||||
BufferAssignmentTypeConverter::ResultConversionKind
|
|
||||||
BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
|
|
||||||
Type converted) {
|
|
||||||
for (auto conversion : resultTypeConversions) {
|
|
||||||
auto res = conversion(origin, converted);
|
|
||||||
if (res != llvm::None)
|
|
||||||
return res.getValue();
|
|
||||||
}
|
|
||||||
return KeepAsFunctionResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BufferAssignmentFuncOpConverter
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// Performs the actual function signature rewriting step.
|
|
||||||
LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
|
|
||||||
mlir::FuncOp funcOp, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const {
|
|
||||||
auto funcType = funcOp.getType();
|
|
||||||
|
|
||||||
// Convert function arguments using the provided TypeConverter.
|
|
||||||
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
|
|
||||||
for (auto argType : llvm::enumerate(funcType.getInputs())) {
|
|
||||||
SmallVector<Type, 2> decomposedTypes, convertedTypes;
|
|
||||||
converter->tryDecomposeType(argType.value(), decomposedTypes);
|
|
||||||
converter->convertTypes(decomposedTypes, convertedTypes);
|
|
||||||
conversion.addInputs(argType.index(), convertedTypes);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert the result types of the function.
|
|
||||||
SmallVector<Type, 2> newResultTypes;
|
|
||||||
newResultTypes.reserve(funcOp.getNumResults());
|
|
||||||
for (Type resultType : funcType.getResults()) {
|
|
||||||
SmallVector<Type, 2> originTypes;
|
|
||||||
converter->tryDecomposeType(resultType, originTypes);
|
|
||||||
for (auto origin : originTypes) {
|
|
||||||
Type converted = converter->convertType(origin);
|
|
||||||
auto kind = converter->getResultConversionKind(origin, converted);
|
|
||||||
if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
|
|
||||||
conversion.addInputs(converted);
|
|
||||||
else
|
|
||||||
// kind = BufferAssignmentTypeConverter::KeepAsFunctionResult
|
|
||||||
newResultTypes.push_back(converted);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), *converter,
|
|
||||||
&conversion)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Update the signature of the function.
|
|
||||||
rewriter.updateRootInPlace(funcOp, [&] {
|
|
||||||
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
|
|
||||||
newResultTypes));
|
|
||||||
});
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// BufferAssignmentCallOpConverter
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// Performs the actual rewriting step.
|
|
||||||
LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
|
|
||||||
CallOp callOp, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const {
|
|
||||||
|
|
||||||
// This class represents a mapping from a result to a list of values and some
|
|
||||||
// results that have not yet constructed. Instead, the indices of these
|
|
||||||
// results in the operation that will be constructed are known. They will be
|
|
||||||
// replaced with the actual values when they are available. The order of
|
|
||||||
// adding to this mapping is important.
|
|
||||||
class ResultMapping {
|
|
||||||
public:
|
|
||||||
ResultMapping() { order = 0; };
|
|
||||||
|
|
||||||
/// Add an available value to the mapping.
|
|
||||||
void addMapping(Value value) {
|
|
||||||
toValuesMapping.push_back({order++, value});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add the index of unavailble result value to the mapping.
|
|
||||||
void addMapping(unsigned index) {
|
|
||||||
toIndicesMapping.push_back({order++, index});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This method returns the mapping values list. The unknown result values
|
|
||||||
/// that only their indicies are available are replaced with their values.
|
|
||||||
void getMappingValues(ValueRange valuesToReplaceIndices,
|
|
||||||
SmallVectorImpl<Value> &values) {
|
|
||||||
// Append available values to the list.
|
|
||||||
SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
|
|
||||||
toValuesMapping.end());
|
|
||||||
// Replace the indices with the actual values.
|
|
||||||
llvm::for_each(
|
|
||||||
toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
|
|
||||||
assert(entry.second < valuesToReplaceIndices.size() &&
|
|
||||||
"The value index is out of range.");
|
|
||||||
res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
|
|
||||||
});
|
|
||||||
// Sort the values based on their adding orders.
|
|
||||||
llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
|
|
||||||
const std::pair<unsigned, Value> &v2) {
|
|
||||||
return v1.first < v2.first;
|
|
||||||
});
|
|
||||||
// Fill the values.
|
|
||||||
llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
|
|
||||||
values.push_back(entry.second);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// Keeping the inserting order of mapping values.
|
|
||||||
int order;
|
|
||||||
|
|
||||||
/// Containing the mapping values with their inserting orders.
|
|
||||||
SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
|
|
||||||
|
|
||||||
/// Containing the indices of result values with their inserting orders.
|
|
||||||
SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
|
|
||||||
};
|
|
||||||
|
|
||||||
Location loc = callOp.getLoc();
|
|
||||||
OpBuilder builder(callOp);
|
|
||||||
SmallVector<Value, 2> newOperands;
|
|
||||||
|
|
||||||
// Create the operands list of the new `CallOp`. It unpacks the decomposable
|
|
||||||
// values if a decompose callback function has been provided by the user.
|
|
||||||
for (auto operand : operands) {
|
|
||||||
SmallVector<Value, 2> values;
|
|
||||||
this->converter->tryDecomposeValue(builder, loc, operand.getType(), operand,
|
|
||||||
values);
|
|
||||||
newOperands.append(values.begin(), values.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the new result types for the new `CallOp` and a mapping from the old
|
|
||||||
// result to new value(s).
|
|
||||||
SmallVector<Type, 2> newResultTypes;
|
|
||||||
SmallVector<ResultMapping, 4> mappings;
|
|
||||||
mappings.resize(callOp.getNumResults());
|
|
||||||
for (auto result : llvm::enumerate(callOp.getResults())) {
|
|
||||||
SmallVector<Type, 2> originTypes;
|
|
||||||
converter->tryDecomposeType(result.value().getType(), originTypes);
|
|
||||||
auto &resultMapping = mappings[result.index()];
|
|
||||||
for (Type origin : originTypes) {
|
|
||||||
Type converted = converter->convertType(origin);
|
|
||||||
auto kind = converter->getResultConversionKind(origin, converted);
|
|
||||||
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
|
|
||||||
newResultTypes.push_back(converted);
|
|
||||||
// The result value is not yet available. Its index is kept and it is
|
|
||||||
// replaced with the actual value of the new `CallOp` later.
|
|
||||||
resultMapping.addMapping(newResultTypes.size() - 1);
|
|
||||||
} else {
|
|
||||||
// kind = BufferAssignmentTypeConverter::AppendToArgumentsList
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
|
||||||
rewriter.restoreInsertionPoint(
|
|
||||||
bufferAssignment->computeAllocPosition(result.value()));
|
|
||||||
MemRefType memref = converted.dyn_cast<MemRefType>();
|
|
||||||
if (!memref)
|
|
||||||
return callOp.emitError("Cannot allocate for a non-Memref type");
|
|
||||||
Value alloc = rewriter.create<AllocOp>(loc, memref);
|
|
||||||
newOperands.push_back(alloc);
|
|
||||||
resultMapping.addMapping(alloc);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
|
|
||||||
newResultTypes, newOperands);
|
|
||||||
|
|
||||||
// Build a replacing value for each result to replace its uses. If a result
|
|
||||||
// has multiple mapping values, it needs to be packed to a single value.
|
|
||||||
OpBuilder nextBuilder(callOp.getOperation()->getNextNode());
|
|
||||||
SmallVector<Value, 2> replacedValues;
|
|
||||||
replacedValues.reserve(callOp.getNumResults());
|
|
||||||
for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
|
|
||||||
SmallVector<Value, 2> valuesToPack;
|
|
||||||
mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
|
|
||||||
if (valuesToPack.empty()) {
|
|
||||||
// No replacement is required.
|
|
||||||
replacedValues.push_back(nullptr);
|
|
||||||
} else if (valuesToPack.size() == 1) {
|
|
||||||
replacedValues.push_back(valuesToPack.front());
|
|
||||||
} else {
|
|
||||||
// Values need to be packed using callback function. The same callback
|
|
||||||
// that is used for materializeArgumentConversion is used for packing.
|
|
||||||
Value packed = converter->materializeArgumentConversion(
|
|
||||||
nextBuilder, loc, callOp.getType(i), valuesToPack);
|
|
||||||
replacedValues.push_back(packed);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rewriter.replaceOp(callOp, replacedValues);
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -111,73 +111,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
|
||||||
// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
|
// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
|
||||||
// CHECK: return %[[Y]]#0
|
// CHECK: return %[[Y]]#0
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
|
|
||||||
// signature of the new signature of the callee function when there are tuple typed
|
|
||||||
// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
|
|
||||||
// arguments. The tuple typed values should be decomposed and composed using
|
|
||||||
// get_tuple_element and make_tuple operations of test dialect. Tensor types are
|
|
||||||
// converted to Memref. Memref typed function results remain as function results.
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @callee
|
|
||||||
func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
|
|
||||||
return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
|
||||||
}
|
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
|
|
||||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
|
|
||||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @caller
|
|
||||||
func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
|
|
||||||
%x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
|
||||||
%y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
|
||||||
return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
|
||||||
}
|
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>)
|
|
||||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>)
|
|
||||||
// CHECK-NEXT: %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
|
|
||||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
|
|
||||||
// CHECK-NEXT: %[[RESULT_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RESULT_TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: %[[CALLEE_RESULTS:.*]]:3 = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]])
|
|
||||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>) -> (memref<2xf32>, i1, memref<5xf32>)
|
|
||||||
// CHECK-NEXT: %[[RETURN_TUPLE:.*]] = "test.make_tuple"(%[[CALLEE_RESULTS]]#0, %[[CALLEE_RESULTS]]#1, %[[CALLEE_RESULTS]]#2)
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[RETURN_TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: return %[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]]
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Test case: Testing BufferAssginmnetFuncOpConverter and
|
|
||||||
// BufferAssginmentReturnOpConverter to see if the return operation matches with
|
|
||||||
// the new function signature when there are tuple typed args and results.
|
|
||||||
// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
|
|
||||||
// typed values should be decomposed and composed using get_tuple_element and
|
|
||||||
// make_tuple operations of test dialect. Tensor types are converted to Memref.
|
|
||||||
// Memref typed function results remain as function results.
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
|
|
||||||
func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
|
|
||||||
return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
|
|
||||||
}
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>
|
|
||||||
// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, i1, f32)
|
|
||||||
// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
|
|
||||||
// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
|
|
||||||
// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[SECOND_TUPLE_SECOND_ELEM]], %[[ARG2]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
|
|
||||||
|
|
|
@ -285,93 +285,8 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
|
||||||
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
|
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
|
||||||
// CHECK: return
|
// CHECK: return
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @func_with_unranked_arg
|
// CHECK-LABEL: func @func_with_unranked_arg
|
||||||
func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
|
func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
|
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Test case: Testing BufferAssginmnetCallOpConverter to see if it matches with the
|
|
||||||
// signature of the new signature of the callee function when there are tuple typed
|
|
||||||
// args and results. BufferAssginmentTypeConverter is set to flatten tuple typed
|
|
||||||
// arguments. The tuple typed values should be decomposed and composed using
|
|
||||||
// get_tuple_element and make_tuple operations of test dialect. Tensor types are
|
|
||||||
// converted to Memref. Memref typed function results are appended to the function
|
|
||||||
// arguments list.
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @callee
|
|
||||||
func @callee(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>){
|
|
||||||
return %arg0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
|
||||||
}
|
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
|
|
||||||
// CHECK-SAME: i1
|
|
||||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
|
|
||||||
// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
|
|
||||||
// CHECK-NEXT: return %[[SECOND_ELEM]]
|
|
||||||
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @caller
|
|
||||||
func @caller(%arg0: tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> tuple<tensor<2xf32>,i1, tensor<5xf32>>{
|
|
||||||
%x0 = call @callee(%arg0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
|
||||||
%y0 = call @callee(%x0) : (tuple<tensor<2xf32>,i1, tensor<5xf32>>) -> (tuple<tensor<2xf32>,i1, tensor<5xf32>>)
|
|
||||||
return %y0 : tuple<tensor<2xf32>,i1, tensor<5xf32>>
|
|
||||||
}
|
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<2xf32>, %[[RESULT1:.*]]: memref<5xf32>)
|
|
||||||
// CHECK-SAME: i1
|
|
||||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
|
|
||||||
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
|
|
||||||
// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
|
|
||||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
|
|
||||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: %[[FIRST_ALLOC:.*]] = alloc()
|
|
||||||
// CHECK-NEXT: %[[SECOND_ALLOC:.*]] = alloc()
|
|
||||||
// CHECK-NEXT: %[[CALLEE_RESULT:.*]] = call @callee(%[[FIRST_ELEM]], %[[SECOND_ELEM]], %[[THIRD_ELEM]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
|
|
||||||
// CHECK-SAME: (memref<2xf32>, i1, memref<5xf32>, memref<2xf32>, memref<5xf32>) -> i1
|
|
||||||
// CHECK-NEXT: %[[TUPLE:.*]] = "test.make_tuple"(%[[FIRST_ALLOC]], %[[CALLEE_RESULT]], %[[SECOND_ALLOC]])
|
|
||||||
// CHECK-NEXT: %[[FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[THIRD_ELEM:.*]] = "test.get_tuple_element"(%[[TUPLE]]) {index = 2 : i32}
|
|
||||||
// CHECK-NEXT: linalg.copy(%[[FIRST_ELEM]], %[[RESULT0]])
|
|
||||||
// CHECK-NEXT: linalg.copy(%[[THIRD_ELEM]], %[[RESULT1]])
|
|
||||||
// CHECK-NEXT: return %[[SECOND_ELEM]]
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// Test case: Testing BufferAssginmnetFuncOpConverter and
|
|
||||||
// BufferAssginmentReturnOpConverter to see if the return operation matches with
|
|
||||||
// the new function signature when there are tuple typed args and results.
|
|
||||||
// BufferAssginmentTypeConverter is set to flatten tuple typed arguments. The tuple
|
|
||||||
// typed values should be decomposed and composed using get_tuple_element and
|
|
||||||
// make_tuple operations of test dialect. Tensor types are converted to Memref.
|
|
||||||
// Memref typed function results are appended to the function arguments list.
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @decompose_tuple_typed_function_args_and_results
|
|
||||||
func @decompose_tuple_typed_function_args_and_results(%arg0: tuple<i1,f32>, %arg1: tensor<10xf32>, %arg2: tuple<i1, tensor<5xf32>>) -> (tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>){
|
|
||||||
return %arg2, %arg1, %arg0 : tuple<i1, tensor<5xf32>>, tensor<10xf32>, tuple<i1,f32>
|
|
||||||
}
|
|
||||||
// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: memref<10xf32>, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: memref<5xf32>, %[[RESULT0:.*]]: memref<5xf32>, %[[RESULT1:.*]]: memref<10xf32>
|
|
||||||
// CHECK-SAME: (i1, i1, f32)
|
|
||||||
// CHECK-NEXT: %[[FIRST_TUPLE:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]])
|
|
||||||
// CHECK-NEXT: %[[SECOND_TUPLE:.*]] = "test.make_tuple"(%[[ARG3]], %[[ARG4]])
|
|
||||||
// CHECK-NEXT: %[[SECOND_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[SECOND_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[SECOND_TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: %[[FIRST_TUPLE_FIRST_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 0 : i32}
|
|
||||||
// CHECK-NEXT: %[[FIRST_TUPLE_SECOND_ELEM:.*]] = "test.get_tuple_element"(%[[FIRST_TUPLE]]) {index = 1 : i32}
|
|
||||||
// CHECK-NEXT: linalg.copy(%[[SECOND_TUPLE_SECOND_ELEM]], %[[RESULT0]])
|
|
||||||
// CHECK-NEXT: linalg.copy(%[[ARG2]], %[[RESULT1]])
|
|
||||||
// CHECK-NEXT: return %[[SECOND_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_FIRST_ELEM]], %[[FIRST_TUPLE_SECOND_ELEM]]
|
|
||||||
|
|
|
@ -1679,31 +1679,4 @@ def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Test BufferPlacement
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
|
|
||||||
let description = [{
|
|
||||||
Test op that returns a specified element of the tuple.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
TupleOf<[AnyType]>,
|
|
||||||
I32Attr:$index
|
|
||||||
);
|
|
||||||
let results = (outs AnyType);
|
|
||||||
}
|
|
||||||
|
|
||||||
def MakeTupleOp: TEST_Op<"make_tuple"> {
|
|
||||||
let description = [{
|
|
||||||
Test op that creates a tuple value from a list of values.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
Variadic<AnyType>:$inputs
|
|
||||||
);
|
|
||||||
let results = (outs TupleOf<[AnyType]>);
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // TEST_OPS
|
#endif // TEST_OPS
|
||||||
|
|
|
@ -11,8 +11,6 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "TestDialect.h"
|
|
||||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
|
@ -111,16 +109,14 @@ struct TestBufferPlacementPreparationPass
|
||||||
|
|
||||||
void populateTensorLinalgToBufferLinalgConversionPattern(
|
void populateTensorLinalgToBufferLinalgConversionPattern(
|
||||||
MLIRContext *context, BufferAssignmentPlacer *placer,
|
MLIRContext *context, BufferAssignmentPlacer *placer,
|
||||||
BufferAssignmentTypeConverter *converter,
|
TypeConverter *converter, OwningRewritePatternList *patterns) {
|
||||||
OwningRewritePatternList *patterns) {
|
|
||||||
populateWithBufferAssignmentOpConversionPatterns<
|
populateWithBufferAssignmentOpConversionPatterns<
|
||||||
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer,
|
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
|
||||||
converter, patterns);
|
allowMemrefFunctionResults>(context, placer, converter, patterns);
|
||||||
patterns->insert<GenericOpConverter>(context, placer, converter);
|
patterns->insert<GenericOpConverter>(context, placer, converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<TestDialect>();
|
|
||||||
registry.insert<linalg::LinalgDialect>();
|
registry.insert<linalg::LinalgDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,8 +127,6 @@ struct TestBufferPlacementPreparationPass
|
||||||
|
|
||||||
// Mark all Standard operations legal.
|
// Mark all Standard operations legal.
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
target.addLegalOp<MakeTupleOp>();
|
|
||||||
target.addLegalOp<GetTupleElementOp>();
|
|
||||||
|
|
||||||
// Mark all Linalg operations illegal as long as they work on tensors.
|
// Mark all Linalg operations illegal as long as they work on tensors.
|
||||||
auto isLegalOperation = [&](Operation *op) {
|
auto isLegalOperation = [&](Operation *op) {
|
||||||
|
@ -155,42 +149,6 @@ struct TestBufferPlacementPreparationPass
|
||||||
converter.isLegal(&funcOp.getBody());
|
converter.isLegal(&funcOp.getBody());
|
||||||
});
|
});
|
||||||
|
|
||||||
auto kind = allowMemrefFunctionResults
|
|
||||||
? BufferAssignmentTypeConverter::KeepAsFunctionResult
|
|
||||||
: BufferAssignmentTypeConverter::AppendToArgumentsList;
|
|
||||||
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
|
|
||||||
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
|
|
||||||
kind);
|
|
||||||
|
|
||||||
converter.addDecomposeTypeConversion(
|
|
||||||
[](TupleType tupleType, SmallVectorImpl<Type> &types) {
|
|
||||||
tupleType.getFlattenedTypes(types);
|
|
||||||
return success();
|
|
||||||
});
|
|
||||||
|
|
||||||
converter.addArgumentMaterialization(
|
|
||||||
[](OpBuilder &builder, TupleType resultType, ValueRange inputs,
|
|
||||||
Location loc) -> Optional<Value> {
|
|
||||||
if (inputs.size() == 1)
|
|
||||||
return llvm::None;
|
|
||||||
TypeRange TypeRange = inputs.getTypes();
|
|
||||||
SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
|
|
||||||
TupleType tuple = TupleType::get(types, builder.getContext());
|
|
||||||
mlir::Value value = builder.create<MakeTupleOp>(loc, tuple, inputs);
|
|
||||||
return value;
|
|
||||||
});
|
|
||||||
|
|
||||||
converter.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
|
|
||||||
TupleType resultType, Value value,
|
|
||||||
SmallVectorImpl<Value> &values) {
|
|
||||||
for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
|
|
||||||
Value res = builder.create<GetTupleElementOp>(
|
|
||||||
loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
|
|
||||||
values.push_back(res);
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Walk over all the functions to apply buffer assignment.
|
// Walk over all the functions to apply buffer assignment.
|
||||||
this->getOperation().walk([&](FuncOp function) -> WalkResult {
|
this->getOperation().walk([&](FuncOp function) -> WalkResult {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
|
|
Loading…
Reference in New Issue