llvm-project/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

686 lines
28 KiB
C++

//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <numeric>
using namespace mlir;
/// Builds an iterator_types list for a linalg.generic in which every one of
/// the `nParallelLoops` loops is marked parallel.
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  SmallVector<StringRef> iteratorTypes;
  iteratorTypes.assign(nParallelLoops, getParallelIteratorTypeName());
  return iteratorTypes;
}
/// Reads the IntegerAttr named `attrName` from `op`, converts its sign-extended
/// value to `T` (which may truncate), and materializes it as a constant whose
/// attribute has type `requiredAttrType`, located at `op`.
template <typename T>
static mlir::ConstantOp
createConstFromIntAttribute(Operation *op, std::string attrName,
                            Type requiredAttrType, PatternRewriter &rewriter) {
  IntegerAttr sourceAttr = op->getAttr(attrName).cast<IntegerAttr>();
  int64_t rawValue = sourceAttr.getValue().getSExtValue();
  T narrowedValue = static_cast<T>(rawValue);
  IntegerAttr constAttr = IntegerAttr::get(requiredAttrType, narrowedValue);
  return rewriter.create<mlir::ConstantOp>(op->getLoc(), constAttr);
}
/// Clamps args[0] between `min` and `max` using compare op `T` with predicate
/// `pred` (callers pass a "less than" predicate). The emitted IR is:
///   lowered = pred(arg, min) ? min : arg
///   result  = pred(max, arg) ? max : lowered
/// Note that the second comparison uses the original argument, not `lowered`.
template <typename T, typename P>
static mlir::SelectOp clampHelper(Operation *op, ValueRange args,
                                  mlir::ConstantOp min, mlir::ConstantOp max,
                                  P pred, PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  Value input = args[0];

  Value belowMin = rewriter.create<T>(loc, pred, input, min);
  Value lowerClamped =
      rewriter.create<mlir::SelectOp>(loc, belowMin, min, input);

  Value aboveMax = rewriter.create<T>(loc, pred, max, input);
  return rewriter.create<mlir::SelectOp>(loc, aboveMax, max, lowerClamped);
}
/// Emits the scalar computation for one element of a TOSA elementwise op at
/// the rewriter's current insertion point (the body of the linalg.generic
/// being built by the caller).
///
/// \param op          The TOSA op being lowered; its kind, element type and
///                    attributes select which scalar ops are emitted.
/// \param args        Scalar block arguments, one per operand of `op`.
/// \param resultTypes Scalar result types for the emitted computation.
/// \returns The value to yield, or a null Value (after notifying a match
///          failure) when the op / element type combination is unsupported.
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  // Dispatch is on the element type of operand 0. tosa.select overrides this
  // below, because its operand 0 is the selection condition.
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    // A non-zero shift has no meaning for floating-point multiplication.
    if (cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    auto mul =
        rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
    // Rebuild the shift amount in the operand element type. Reusing the
    // `shift` attribute as-is would give a constant whose attribute type
    // differs from its result type for non-i32 element types (e.g. i8/i16),
    // which is invalid IR.
    auto constant = createConstFromIntAttribute<int32_t>(op, "shift",
                                                         elementTy, rewriter);
    return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
                                                     constant);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>()) {
    // Integer negation is emitted as multiplication by -1.
    auto constant =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, -1));
    return rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], constant);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseXOrOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    // The data element type comes from operand 1; operand 0 is the condition.
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(op, args, min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
                                                    rewriter);
    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                    rewriter);
    return clampHelper<mlir::CmpIOp>(op, args, min, max, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::ReluNOp: clamp to [0, N].
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(op, args, zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(op, args, zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
/// Rewrites a single-result TOSA elementwise op into a linalg.generic over
/// statically shaped tensors.
///
/// Operand dimensions of size 1 are broadcast by indexing them with the
/// constant 0; every other dimension uses the loop index directly. Results use
/// identity indexing maps. The scalar body is produced by
/// createLinalgBodyCalculationForElementwiseOp.
///
/// \returns failure() if the first operand is not a ShapedType, a result is
///          not statically shaped, or no scalar lowering exists for the
///          op / element type combination.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();
  auto results = operation->getResults();
  auto t0 = operation->getOperand(0).getType().template dyn_cast<ShapedType>();
  if (!t0)
    return rewriter.notifyMatchFailure(operation,
                                       "All inputs must be a shaped type");

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;
  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;
  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          operation,
          "tosa to linalg conversion expects statically shaped tensors");

    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  unsigned nloops = t0.getRank();
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Type types : operation->getOperandTypes()) {
    auto shape = types.cast<ShapedType>().getShape();
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(nloops);
    for (unsigned i = 0; i < nloops; ++i) {
      // If the dimension is one we can broadcast the input with a constant
      // affine expression.
      if (shape[i] == 1)
        dimExprs.push_back(rewriter.getAffineConstantExpr(0));
      else
        dimExprs.push_back(rewriter.getAffineDimExpr(i));
    }
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
                                          /*symbolCount=*/0, dimExprs,
                                          rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(nloops));

  // Set inside the body builder when no scalar lowering exists for `operation`.
  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operation->getOperands(), initTensors, indexingMaps,
      getNParallelLoopsAttrs(nloops),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        // A null result means the op / element type is unsupported.
        if (!opResult)
          didEncounterError = true;
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}
// Picks the starting accumulator value for a TOSA reduction:
//   sum -> 0, prod -> 1,
//   min -> largest finite / signed-max value,
//   max -> most negative finite / signed-min value.
// The attribute is typed with `elementTy`. A null Attribute is returned for
// unsupported op / element type combinations.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (auto floatTy = elementTy.dyn_cast<FloatType>()) {
    const auto &semantics = floatTy.getFloatSemantics();
    if (isa<tosa::ReduceSumOp>(op))
      return rewriter.getFloatAttr(elementTy, 0.0);
    if (isa<tosa::ReduceProdOp>(op))
      return rewriter.getFloatAttr(elementTy, 1.0);
    if (isa<tosa::ReduceMinOp>(op))
      return rewriter.getFloatAttr(
          elementTy, APFloat::getLargest(semantics, /*Negative=*/false));
    if (isa<tosa::ReduceMaxOp>(op))
      return rewriter.getFloatAttr(
          elementTy, APFloat::getLargest(semantics, /*Negative=*/true));
    return {};
  }

  if (elementTy.isa<IntegerType>()) {
    if (isa<tosa::ReduceSumOp>(op))
      return rewriter.getIntegerAttr(elementTy, 0);
    if (isa<tosa::ReduceProdOp>(op))
      return rewriter.getIntegerAttr(elementTy, 1);
    unsigned bitWidth = elementTy.getIntOrFloatBitWidth();
    if (isa<tosa::ReduceMinOp>(op))
      return rewriter.getIntegerAttr(elementTy,
                                     APInt::getSignedMaxValue(bitWidth));
    if (isa<tosa::ReduceMaxOp>(op))
      return rewriter.getIntegerAttr(elementTy,
                                     APInt::getSignedMinValue(bitWidth));
  }

  return {};
}
// Emits the scalar combine step of a reduction from the block arguments
// `args` of the linalg.generic body. Sum and product use add / mul on all of
// `args`. Min and max compare args[0] with args[1] and select the winner.
// Float and signless-integer variants are chosen from `elementTy`. Returns a
// null Value for unsupported op / element type combinations.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  bool isFloat = elementTy.isa<FloatType>();
  bool isInt = elementTy.isa<IntegerType>();
  if (!isFloat && !isInt)
    return {};

  if (isa<tosa::ReduceSumOp>(op)) {
    if (isFloat)
      return rewriter.create<AddFOp>(loc, args);
    return rewriter.create<AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op)) {
    if (isFloat)
      return rewriter.create<MulFOp>(loc, args);
    return rewriter.create<MulIOp>(loc, args);
  }

  bool isMin = isa<tosa::ReduceMinOp>(op);
  bool isMax = isa<tosa::ReduceMaxOp>(op);
  if (!isMin && !isMax)
    return {};

  Value keepFirst;
  if (isFloat) {
    CmpFPredicate fpred = isMin ? CmpFPredicate::OLT : CmpFPredicate::OGT;
    keepFirst = rewriter.create<mlir::CmpFOp>(loc, fpred, args[0], args[1]);
  } else {
    CmpIPredicate ipred = isMin ? CmpIPredicate::slt : CmpIPredicate::sgt;
    keepFirst = rewriter.create<mlir::CmpIOp>(loc, ipred, args[0], args[1]);
  }
  return rewriter.create<mlir::SelectOp>(loc, keepFirst, args[0], args[1]);
}
// Performs the match and rewrite for reduction operations. This includes
// declaring a correctly sized initial value, and the linalg.generic operation
// that reduces across the specified axis.
//
// The output tensor (shaped like the op's result) is filled with the value
// from createInitialValueForReduceOp. A linalg.generic then iterates over
// every input dimension: `axis` is a reduction loop and the others are
// parallel. The output indexing map omits `axis`. Returns failure() if no
// initial value or no scalar combine step exists for the op / element type.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
  auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // First fill the output buffer with the init value.
  auto initTensor = rewriter
                        .create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
                                                      resultTy.getShape(),
                                                      resultTy.getElementType())
                        .result();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
  auto filledTensor =
      rewriter.create<linalg::FillOp>(loc, initTensor, fillValue).result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<StringRef, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));

    iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
                                      : getParallelIteratorTypeName());
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  // Set inside the body builder when no scalar combine step exists.
  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, resultTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        // A null result means the op / element type is unsupported.
        if (!result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
  return success();
}
namespace {
/// Rewrite pattern that lowers an elementwise TOSA op of type `SrcOp` to a
/// linalg.generic by forwarding to elementwiseMatchAndRewriteHelper.
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Operation *genericOp = op.getOperation();
    return elementwiseMatchAndRewriteHelper(genericOp, rewriter);
  }
};
/// Lowers tosa.reshape on statically shaped tensors to linalg.tensor_reshape.
///
/// When every dimension of the lower-rank shape equals the product of a run of
/// consecutive dimensions of the higher-rank shape, one tensor_reshape with
/// that reassociation is emitted. Otherwise the input is collapsed to a 1-D
/// tensor and then expanded to the result shape.
class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    typename tosa::ReshapeOp::Adaptor operands(args);

    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().template cast<ShapedType>();

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
        collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in source are collapsed. For
    // such case we can just generate one single linalg.reshape.
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      // Grow the current group until its element count reaches dstSize. The
      // next dimension must exist before it is read. Without this check, a
      // group whose element count never reaches dstSize would index past the
      // end of expandedShape.
      while (srcSize < dstSize && currSrcDim + 1 < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }
    if (currSrcDim != expandedShape.size() ||
        currDstDim != collapsedShape.size())
      isCollapsingSource = false;

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      // The init value fixes std::accumulate's accumulator type. It must be
      // int64_t, because an `int` init would truncate element counts that do
      // not fit in 32 bits.
      int64_t totalElems =
          std::accumulate(expandedShape.begin(), expandedShape.end(),
                          int64_t(1), std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operands
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
          loc, collapsedTy, args[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
        reshape, resultTy, args[0], reassociationMap);

    return success();
  }
};
/// Lowers tosa.transpose whose permutation is a constant into a linalg.generic.
/// The body forwards each input element unchanged. The permutation is encoded
/// in the input indexing map, and the output uses the identity map.
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms)))
      return failure();

    auto resultTy = op.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return failure();

    // Loop dimension `loopDim` is placed at input position perms[loopDim].
    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    unsigned loopDim = 0;
    for (auto permValue : perms.getIntValues()) {
      inputExprs[permValue.getZExtValue()] =
          rewriter.getAffineDimExpr(loopDim);
      ++loopDim;
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType());

    AffineMap inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                        inputExprs, rewriter.getContext());
    AffineMap outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    SmallVector<AffineMap, 2> affineMaps = {inputMap, outputMap};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          Value inputElement = args.front();
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), inputElement);
        });
    return success();
  }
};
// Identity-style ops are dropped during codegen: every use of a result is
// redirected to the corresponding operand. Any identity that must survive
// (for example, to mark a cross-device boundary) has to be dealt with before
// this lowering runs.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    Operation *identity = op.getOperation();
    rewriter.replaceOp(identity, identity->getOperands());
    return success();
  }
};
/// Rewrite pattern for the TOSA reduce_* ops. It reads the op's `axis`
/// attribute and forwards to reduceMatchAndRewriteHelper.
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    uint64_t reductionAxis = reduceOp.axis();
    return reduceMatchAndRewriteHelper(reduceOp, reductionAxis, rewriter);
  }
};
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  // Elementwise ops, each lowered to a linalg.generic with a scalar body.
  patterns->insert<
      PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
      context);

  // Identity ops, forwarded to their operands.
  patterns->insert<IdentityNConverter<tosa::IdentityOp>,
                   IdentityNConverter<tosa::IdentityNOp>>(context);

  // Single-axis reductions.
  patterns->insert<ReduceConverter<tosa::ReduceMinOp>,
                   ReduceConverter<tosa::ReduceMaxOp>,
                   ReduceConverter<tosa::ReduceSumOp>,
                   ReduceConverter<tosa::ReduceProdOp>>(context);

  // Shape / data-movement ops.
  patterns->insert<ReshapeOpConverter, TransposeConverter>(context);
}