//===- Detensorize.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include <iterator>
|
|
#include <memory>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                           ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  if (inputs[0].getType().isa<TensorType>())
    return nullptr;

  // A detensored value is converted back by creating a new tensor from its
  // element(s).
  auto createNewTensorOp =
      builder.create<tensor::FromElementsOp>(loc, inputs[0]);

  // FromElementsOp results in a tensor<1xdtype>; we need to reshape that into
  // a tensor<dtype> instead.
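  //
  // For illustration (a sketch mirroring the canonicalization comment below),
  // materializing a detensored element %element back to tensor<i32> emits:
  //
  //   %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
  //   %reshaped_tensor = tensor.collapse_shape %tensor []
  //       : tensor<1xi32> into tensor<i32>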
  return builder.create<tensor::CollapseShapeOp>(
      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
}

namespace {
/// Defines the criteria a TensorType must follow in order to be considered
/// "detensorable".
///
/// NOTE: For now, only 0-D tensors are supported.
///
/// Returns true if tensorType can be detensored.
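///
/// For example, `tensor<i32>` is detensorable, while `tensor<4xi32>` and
/// unranked tensors are not.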
bool canBeDetensored(TensorType tensorType) {
  return tensorType.hasRank() && tensorType.getRank() == 0;
}

bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
  GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
  return genericOp &&
         llvm::all_of(
             genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
               return !typeConverter.isLegal(opOperand->get().getType());
             });
}

/// A conversion pattern for detensoring `linalg.generic` ops.
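///
/// For example (a sketch with attributes elided), a detensorable op such as
///
///   %res = linalg.generic #attrs
///            ins(%a, %b : tensor<i32>, tensor<i32>)
///            outs(%init : tensor<i32>) {
///          ^bb0(%x: i32, %y: i32, %out: i32):
///            %sum = arith.addi %x, %y : i32
///            linalg.yield %sum : i32
///          } -> tensor<i32>
///
/// is rewritten by inlining its region into the enclosing block, so the
/// scalar `arith.addi` operates directly on the detensored operands and all
/// uses of `%res` are replaced with the yielded scalar.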
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Block *originalBlock = op->getBlock();

    // Gather some information about the op before inlining its region.
    Block *opEntryBlock = &*op.region().begin();
    YieldOp yieldOp = dyn_cast<YieldOp>(op.region().back().getTerminator());

    // Split the enclosing block before the op. This way, we have a clear
    // insertion point in which the op's region can be inlined.
    Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
    rewriter.inlineRegionBefore(op.region(), newBlock);
    // Now that the op's region is inlined, the operands of its YieldOp are
    // mapped to the materialized target values. Therefore, we can replace the
    // op's uses with those of its YieldOp's operands.
    rewriter.replaceOp(op, yieldOp->getOperands());

    // No need for these intermediate blocks; merge them into one.
    rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands());
    rewriter.mergeBlocks(newBlock, originalBlock, {});

    rewriter.eraseOp(&*Block::iterator(yieldOp));

    return success();
  }
};

/// A conversion pattern for detensoring internal (non-entry) blocks within a
/// function.
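///
/// For every block argument selected by the cost model, the argument's type
/// is replaced with its detensored counterpart; illustratively:
///
///   ^bb1(%arg : tensor<i32>)   becomes   ^bb1(%arg : i32)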
struct FunctionNonEntryBlockConversion : public ConversionPattern {
  FunctionNonEntryBlockConversion(MLIRContext *ctx, TypeConverter &converter,
                                  DenseSet<BlockArgument> blockArgsToDetensor)
      : ConversionPattern(converter, MatchTraitOpTypeTag(),
                          TypeID::get<OpTrait::FunctionLike>(), /*benefit=*/1,
                          ctx),
        blockArgsToDetensor(blockArgsToDetensor) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.startRootUpdate(op);
    Region &region = function_like_impl::getFunctionBody(op);
    SmallVector<TypeConverter::SignatureConversion, 2> conversions;

    for (Block &block : llvm::drop_begin(region, 1)) {
      conversions.emplace_back(block.getNumArguments());
      TypeConverter::SignatureConversion &back = conversions.back();

      for (BlockArgument blockArgument : block.getArguments()) {
        int idx = blockArgument.getArgNumber();

        if (blockArgsToDetensor.count(blockArgument))
          back.addInputs(idx, {getTypeConverter()->convertType(
                                  block.getArgumentTypes()[idx])});
        else
          back.addInputs(idx, {block.getArgumentTypes()[idx]});
      }
    }

    if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
                                                   conversions))) {
      rewriter.cancelRootUpdate(op);
      return failure();
    }

    rewriter.finalizeRootUpdate(op);
    return success();
  }

private:
  const DenseSet<BlockArgument> blockArgsToDetensor;
};

class DetensorizeTypeConverter : public TypeConverter {
public:
  DetensorizeTypeConverter() {
    addConversion([](Type type) { return type; });

    // A TensorType that can be detensored is converted to the underlying
    // element type.
    addConversion([](TensorType tensorType) -> Type {
      if (canBeDetensored(tensorType))
        return tensorType.getElementType();

      return tensorType;
    });

    // A tensor value is detensorized by extracting its element(s).
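    //
    // For a 0-D tensor this is a single extract, e.g. (illustrative):
    //   %element = tensor.extract %tensor[] : tensor<i32>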
    addTargetMaterialization([](OpBuilder &builder, Type type,
                                ValueRange inputs, Location loc) -> Value {
      return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{});
    });

    addSourceMaterialization(sourceMaterializationCallback);
    addArgumentMaterialization(sourceMaterializationCallback);
  }
};

/// Canonicalizes the pattern of the form
///
///   %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
///   %reshaped_tensor = tensor.collapse_shape %tensor []
///       : tensor<1xi32> into tensor<i32>
///   %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
///
/// to just %element.
struct ExtractFromReshapeFromElements
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    if (!extract.indices().empty())
      return failure();

    auto tensorReshape =
        extract.tensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (tensorReshape == nullptr)
      return failure();

    auto tensorFromElements =
        tensorReshape.getOperand()
            .getDefiningOp<mlir::tensor::FromElementsOp>();
    if (tensorFromElements == nullptr)
      return failure();

    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
    return success();
  }
};

/// @see LinalgDetensorize in Linalg/Passes.td for more details.
struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
  LinalgDetensorize() = default;
  LinalgDetensorize(const LinalgDetensorize &pass)
      : LinalgDetensorizeBase<LinalgDetensorize>() {}

  class CostModel {
  public:
    virtual ~CostModel() = default;

    /// A cost model algorithm computes the following outputs:
    ///
    /// - opsToDetensor: the list of linalg ops that should be
    /// detensored.
    ///
    /// - blockArgsToDetensor: since the operands and results of detensored
    /// linalg ops can cross the BB boundary (e.g. a linalg op's input can come
    /// from a BB argument and a linalg op's output can be passed to successor
    /// BBs), we need to maintain the subset of arguments that should be
    /// detensored (i.e. converted by typeConverter) for each affected BB.
    ///
    /// Example:
    ///
    /// For the following snippet:
    /// ...
    /// ^bb1(%6: tensor<i32>, %9: tensor<i32>):
    ///   %7 = linalg.init_tensor [] : tensor<i32>
    ///   %8 = linalg.generic #attrs
    ///     ins(%6, %6 : tensor<i32>, tensor<i32>)
    ///     outs(%7 : tensor<i32>) {
    ///     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
    ///       %9 = arith.addi %arg0, %arg1 : i32
    ///       linalg.yield %9 : i32
    ///   } -> tensor<i32>
    ///   %10 = "some.op"(%9)
    ///   br ^bb2(%8 : tensor<i32>)
    /// ...
    ///
    /// if the cost model decides that the linalg.generic op should be
    /// detensored, then:
    /// - opsToDetensor should be = {linalg.generic{add}}.
    /// - blockArgsToDetensor should be = {bb1 -> {0}, bb2 -> {0}}.
    virtual void compute(Operation *func,
                         DetensorizeTypeConverter typeConverter,
                         DenseSet<Operation *> &opsToDetensor,
                         DenseSet<BlockArgument> &blockArgsToDetensor) = 0;

    /// From the blockArgsToDetensor set computed by a CostModel
    /// implementation, this method computes the corresponding branch op
    /// detensoring. The result is a map from a branch op to a subset of indices
    /// of its operands. The indices specify which of the branch op's operands
    /// should be detensored.
    ///
    /// For the previous example, this method would compute: {bb2 -> {0}}.
    static DenseMap<Operation *, DenseSet<int>> computeBranchOpDetensoring(
        const DenseSet<BlockArgument> &blockArgsToDetensor) {
      DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;

      for (auto blockArgumentElem : blockArgsToDetensor) {
        Block *block = blockArgumentElem.getOwner();

        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          detensorableBranchOps[terminator].insert(
              blockOperands->getBeginOperandIndex() +
              blockArgumentElem.getArgNumber());
        }
      }

      return detensorableBranchOps;
    }
  };

  /// Detensorize linalg ops involved in control-flow within a function.
  ///
  /// This model starts from BranchOps and CondBranchOps within a function. For
  /// each such branch, the model then walks the use-def chain for the branch's
  /// condition backwards in order to understand where the condition's value
  /// comes from. If the condition value is (indirectly) computed by a linalg
  /// op that can be detensored, the model then continues walking the use-def
  /// chain in order to understand where the linalg op's operands come from.
  /// This leads to discovering a "detensoring component". A detensoring
  /// component is the set of operations + block arguments that are involved in
  /// control-flow AND can be detensored.
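  ///
  /// For instance (a sketch, with the generic op's body elided), in
  ///
  ///   %cmp = linalg.generic ... ins(%a, %b : tensor<i32>, tensor<i32>)
  ///                             outs(%init : tensor<i1>) {...} -> tensor<i1>
  ///   %cond = tensor.extract %cmp[] : tensor<i1>
  ///   cond_br %cond, ^bb1(%a : tensor<i32>), ^bb2
  ///
  /// the generic op, the extract, and ^bb1's argument together form one
  /// detensoring component.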
  class ControlFlowDetectionModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      SmallVector<Value> workList;

      func->walk([&](CondBranchOp condBr) {
        for (auto operand : condBr.getOperands()) {
          workList.push_back(operand);
        }
      });

      func->walk([&](BranchOp br) {
        for (auto operand : br.getOperands()) {
          workList.push_back(operand);
        }
      });

      DenseSet<Value> visitedValues;
      DenseSet<Operation *> visitedOps;

      // For a (to-be-detensored) value, check if it "escapes" the block by
      // being passed to a terminator. If it does, then workList is updated
      // with the corresponding argument to the successor block.
      auto updateWorkListWithSuccessorArguments =
          [&](Value value, BranchOpInterface terminator) {
            if (!terminator)
              return;

            for (auto operandIdx :
                 llvm::seq<unsigned>(0, terminator->getOperands().size())) {
              Value operand = terminator->getOperand(operandIdx);

              if (operand == value) {
                auto succBlockArg =
                    terminator.getSuccessorBlockArgument(operandIdx);

                if (succBlockArg && !blockArgsToDetensor.count(*succBlockArg))
                  workList.push_back(*succBlockArg);
              }
            }
          };

      while (!workList.empty()) {
        Value currentItem = workList.pop_back_val();

        if (!visitedValues.insert(currentItem).second)
          continue;

        // 1 - Look forward:
        // 1.1 - If currentItem escapes to one or more successors, add
        // the corresponding successor arguments to workList.
        updateWorkListWithSuccessorArguments(
            currentItem, dyn_cast<BranchOpInterface>(
                             currentItem.getParentBlock()->getTerminator()));

        // 1.2 - For each user of currentItem, add the defined values to
        // workList. This way, the user ops can be inspected later if they are
        // detensorable and if so, their operands will be added to workList to
        // potentially discover other parts of the detensorable component.
        for (auto *user : currentItem.getUsers())
          for (Value result : user->getResults())
            workList.push_back(result);

        // 2 - Look backward:
        // 2.1 - The current item is defined by a block argument. If the owner
        // block is a non-entry one, then:
        // * Add the argument to blockArgsToDetensor.
        // * Walk the use-def chain backwards to add each predecessor's
        // terminator-operands corresponding to currentItem to workList.
        if (currentItem.dyn_cast<BlockArgument>()) {
          BlockArgument currentItemBlockArgument =
              currentItem.cast<BlockArgument>();
          Block *ownerBlock = currentItemBlockArgument.getOwner();

          // Function arguments are not detensored/converted.
          if (&*ownerBlock->getParent()->begin() == ownerBlock)
            continue;

          // This inner-block argument is involved in control-flow, so it
          // should be detensored.
          blockArgsToDetensor.insert(currentItemBlockArgument);

          for (PredecessorIterator pred = ownerBlock->pred_begin();
               pred != ownerBlock->pred_end(); ++pred) {
            BranchOpInterface predTerminator =
                dyn_cast<BranchOpInterface>((*pred)->getTerminator());

            // TODO: For now, we give up if any of the control-flow components
            // in a function is not detensorable. Fix that.
            if (!predTerminator) {
              opsToDetensor.clear();
              blockArgsToDetensor.clear();
              return;
            }

            auto ownerBlockOperands =
                predTerminator.getSuccessorOperands(pred.getSuccessorIndex());

            if (!ownerBlockOperands || ownerBlockOperands->empty())
              continue;

            // For each predecessor, add the value it passes to that argument
            // to workList to find out how it's computed.
            workList.push_back(
                ownerBlockOperands
                    .getValue()[currentItemBlockArgument.getArgNumber()]);
          }

          continue;
        }

        Operation *currentItemDefiningOp = currentItem.getDefiningOp();

        if (!visitedOps.insert(currentItemDefiningOp).second)
          continue;

        // 2.2 - The current item is computed by a GenericOp. If the op should
        // be detensored, then:
        // * Add it to opsToDetensor.
        // * Add its operands to workList to discover other parts of the
        // potentially detensorable component.
        if (auto genericOp = dyn_cast<GenericOp>(currentItemDefiningOp)) {
          // The op was encountered already; no need to inspect it again.
          if (opsToDetensor.count(genericOp))
            continue;

          // The op should not be detensored; give up on it but continue with
          // discovering the rest of the control-flow component.
          if (!shouldBeDetensored(genericOp, typeConverter)) {
            continue;
          }

          opsToDetensor.insert(genericOp);

          for (Value genericOpOperand : genericOp.inputs())
            workList.push_back(genericOpOperand);

          continue;
        }

        // 2.3 - The current item is the result of a FromElementsOp; it will be
        // trivially detensored later as part of the canonicalization patterns
        // applied at the end of detensoring.
        //
        // Note: No need to check whether the result type of this op is
        // detensorable since if it wasn't we wouldn't reach that point in the
        // work list.
        if (dyn_cast<tensor::FromElementsOp>(currentItemDefiningOp))
          continue;

        // 2.4 - The current item is the result of a scalar op; add all its
        // operands to the work list.
        if (llvm::all_of(
                currentItemDefiningOp->getResultTypes(),
                [&](Type resultType) { return resultType.isIntOrFloat(); }))
          for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
            workList.push_back(scalarOpOperand);
      }

      // Since the cost model gives up on some ops (see the details of step 2.2
      // above), block arguments that correspond to the values produced by
      // those ops should not be detensored either.

      DenseSet<BlockArgument> blockArgsToRemove;

      for (auto &blockArg : blockArgsToDetensor) {
        Block *block = blockArg.getParentBlock();

        // For the potentially detensorable block argument, find the
        // corresponding operands in predecessor blocks.
        for (PredecessorIterator pred = block->pred_begin();
             pred != block->pred_end(); ++pred) {
          BranchOpInterface terminator =
              dyn_cast<BranchOpInterface>((*pred)->getTerminator());
          auto blockOperands =
              terminator.getSuccessorOperands(pred.getSuccessorIndex());

          if (!blockOperands || blockOperands->empty())
            continue;

          Operation *definingOp =
              terminator
                  ->getOperand(blockOperands->getBeginOperandIndex() +
                               blockArg.getArgNumber())
                  .getDefiningOp();

          // If the operand is defined by a GenericOp that will not be
          // detensored, then do not detensor the corresponding block argument.
          if (dyn_cast_or_null<GenericOp>(definingOp) &&
              opsToDetensor.count(definingOp) == 0) {
            blockArgsToRemove.insert(blockArg);
            break;
          }
        }
      }

      for (auto &blockArg : blockArgsToRemove) {
        blockArgsToDetensor.erase(blockArg);
      }
    }
  };

  /// Detensorize everything that can be detensored.
  class AggressiveDetensoringModel : public CostModel {
  public:
    void compute(Operation *func, DetensorizeTypeConverter typeConverter,
                 DenseSet<Operation *> &opsToDetensor,
                 DenseSet<BlockArgument> &blockArgsToDetensor) override {
      func->walk([&](GenericOp genericOp) {
        if (shouldBeDetensored(genericOp, typeConverter))
          opsToDetensor.insert(genericOp);
      });

      for (Block &block :
           llvm::drop_begin(function_like_impl::getFunctionBody(func), 1))
        for (BlockArgument blockArgument : block.getArguments())
          blockArgsToDetensor.insert(blockArgument);
    }
  };

  void runOnOperation() override {
    assert(getOperation()->hasTrait<OpTrait::FunctionLike>() &&
           "DetensorizePass can only be run on FunctionLike operations");
    MLIRContext *context = &getContext();
    DetensorizeTypeConverter typeConverter;
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    DenseSet<Operation *> opsToDetensor;
    DenseMap<Operation *, DenseSet<int>> detensorableBranchOps;
    DenseSet<BlockArgument> blockArgsToDetensor;

    if (aggressiveMode.getValue()) {
      AggressiveDetensoringModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    } else {
      ControlFlowDetectionModel costModel;
      costModel.compute(getOperation(), typeConverter, opsToDetensor,
                        blockArgsToDetensor);
    }

    detensorableBranchOps =
        CostModel::computeBranchOpDetensoring(blockArgsToDetensor);

    target.addDynamicallyLegalOp<GenericOp>(
        [&](GenericOp op) { return !opsToDetensor.count(op); });

    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      // A function is legal if all of its non-entry blocks are legal. We
      // don't legalize the entry block (i.e. the function's signature)
      // since detensoring can't happen along external calling convention
      // boundaries, which we conservatively approximate as all function
      // signatures.
      if (op->hasTrait<OpTrait::FunctionLike>()) {
        auto &body = function_like_impl::getFunctionBody(op);
        return llvm::all_of(llvm::drop_begin(body, 1), [&](Block &block) {
          if (llvm::any_of(
                  blockArgsToDetensor, [&](BlockArgument blockArgument) {
                    return blockArgument.getOwner() == &block &&
                           !typeConverter.isLegal(blockArgument.getType());
                  })) {
            return false;
          }
          return true;
        });
      }

      if (isNotBranchOpInterfaceOrReturnLikeOp(op) ||
          isLegalForReturnOpTypeConversionPattern(op, typeConverter,
                                                  /*returnOpAlwaysLegal*/ true))
        return true;

      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
        if (!detensorableBranchOps.count(branchOp))
          return true;

        for (auto operandIdx : detensorableBranchOps[branchOp])
          if (!typeConverter.isLegal(
                  branchOp->getOperand(operandIdx).getType()))
            return false;

        return true;
      }

      return false;
    });

    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
    patterns.insert<FunctionNonEntryBlockConversion>(context, typeConverter,
                                                     blockArgsToDetensor);
    // Since non-entry block arguments get detensorized, we also need to
    // update the control flow inside the function to reflect the correct
    // types.
    auto shouldConvertBranchOperand = [&](BranchOpInterface branchOp,
                                          int operandIdx) -> bool {
      return detensorableBranchOps.count(branchOp) &&
             detensorableBranchOps[branchOp].count(operandIdx);
    };

    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                   shouldConvertBranchOperand);

    if (failed(
            applyFullConversion(getOperation(), target, std::move(patterns))))
      signalPassFailure();

    RewritePatternSet canonPatterns(context);
    canonPatterns.add<ExtractFromReshapeFromElements>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(canonPatterns))))
      signalPassFailure();
  }

  Option<bool> aggressiveMode{
      *this, "aggressive-mode",
      llvm::cl::desc("Detensorize all ops that qualify for detensoring along "
                     "with branch operands and basic-block arguments.")};
};
} // namespace

std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
  return std::make_unique<LinalgDetensorize>();
}