forked from OSchip/llvm-project
686 lines
28 KiB
C++
686 lines
28 KiB
C++
//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// These rewriters lower from the Tosa to the Linalg dialect.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
|
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
|
|
/// Returns a list containing `nParallelLoops` copies of the parallel iterator
/// type name, for use as the iterator types of a linalg.generic op whose
/// loops are all parallel.
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
  SmallVector<StringRef> iteratorTypes;
  iteratorTypes.assign(nParallelLoops, getParallelIteratorTypeName());
  return iteratorTypes;
}
|
|
|
|
template <typename T>
|
|
static mlir::ConstantOp
|
|
createConstFromIntAttribute(Operation *op, std::string attrName,
|
|
Type requiredAttrType, PatternRewriter &rewriter) {
|
|
auto castedN = static_cast<T>(
|
|
op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
|
|
return rewriter.create<mlir::ConstantOp>(
|
|
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
|
|
}
|
|
|
|
template <typename T, typename P>
|
|
static mlir::SelectOp clampHelper(Operation *op, ValueRange args,
|
|
mlir::ConstantOp min, mlir::ConstantOp max,
|
|
P pred, PatternRewriter &rewriter) {
|
|
Location loc = op->getLoc();
|
|
auto smallerThanMin = rewriter.create<T>(loc, pred, args[0], min);
|
|
auto minOrArg =
|
|
rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, args[0]);
|
|
auto largerThanMax = rewriter.create<T>(loc, pred, max, args[0]);
|
|
return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
|
|
}
|
|
|
|
/// Creates the scalar computation that forms the body of the linalg.generic
/// op generated for the elementwise TOSA operation `op`.
///
/// `args` are the scalar block arguments corresponding to the operands of
/// `op`, and `resultTypes` the scalar result types of the body. The op kind
/// together with the element type of operand 0 (operand 1 for tosa.select)
/// selects which ops are emitted.
///
/// Returns the computed scalar value, or a null Value after notifying
/// `rewriter` of a match failure when the op / element type combination is
/// not handled here.
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      op->getOperand(0).getType().cast<ShapedType>().getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AbsFOp>(loc, resultTypes, args);

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
    // The op kind was just established with isa<>, so use the asserting cast<>
    // rather than a dyn_cast<> whose result would be used unchecked.
    if (cast<tosa::MulOp>(op).shift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
  }

  if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
    // Integer multiply followed by a signed right shift by the "shift"
    // attribute.
    auto mul =
        rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], args[1]);
    auto constant =
        rewriter.create<mlir::ConstantOp>(loc, elementTy, op->getAttr("shift"));
    return rewriter.create<mlir::SignedShiftRightOp>(loc, resultTypes, mul,
                                                     constant);
  }

  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>()) {
    // Integer negation is emitted as a multiplication by -1.
    auto constant =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, -1));
    return rewriter.create<mlir::MulIOp>(loc, resultTypes, args[0], constant);
  }

  if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::NegFOp>(loc, resultTypes, args);

  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::AndOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::OrOp>(loc, resultTypes, args);

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::XOrOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::ShiftLeftOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
    return rewriter.create<mlir::UnsignedShiftRightOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT, args[0],
                                         args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt, args[0],
                                         args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE, args[0],
                                         args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sge, args[0],
                                         args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    // Operand 0 is the condition, so the element type that decides whether
    // the op is supported is taken from operand 1 instead.
    elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
    if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
      return rewriter.create<mlir::SelectOp>(loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
    auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
                                                   args[0], args[1]);
    return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::CeilFOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
    return rewriter.create<mlir::FloorFOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
    auto min = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("min_fp"));
    auto max = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                                 op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(op, args, min, max, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
    auto min = createConstFromIntAttribute<int32_t>(op, "min_int", elementTy,
                                                    rewriter);
    auto max = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                    rewriter);
    return clampHelper<mlir::CmpIOp>(op, args, min, max, CmpIPredicate::slt,
                                     rewriter);
  }

  // tosa::ReluNOp
  if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
    // ReluN is emitted as a clamp between zero and the "max_fp" attribute.
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
    auto n = rewriter.create<mlir::ConstantOp>(loc, elementTy,
                                               op->getAttr("max_fp"));
    return clampHelper<mlir::CmpFOp>(op, args, zero, n, CmpFPredicate::OLT,
                                     rewriter);
  }

  if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
    // ReluN is emitted as a clamp between zero and the "max_int" attribute.
    auto zero =
        rewriter.create<mlir::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
                                                  rewriter);
    return clampHelper<mlir::CmpIOp>(op, args, zero, n, CmpIPredicate::slt,
                                     rewriter);
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
|
|
|
|
/// Rewrites the elementwise TOSA op `operation` into a linalg.generic op with
/// all-parallel loops.
///
/// An init tensor is created for each result (statically shaped results
/// only), an indexing map is built per operand in which every size-1
/// dimension is mapped to the constant 0 so that it broadcasts, and the
/// scalar body is produced by createLinalgBodyCalculationForElementwiseOp.
///
/// Returns failure if operand 0 is not a shaped type, if a result is not
/// statically shaped, or if no scalar body could be generated for the op.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();
  auto results = operation->getResults();
  // Note: the check below inspects operand 0 although the message talks about
  // results; the message text is kept unchanged.
  auto t0 = operation->getOperand(0).getType().dyn_cast<ShapedType>();
  if (!t0)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;
  for (auto result : results) {
    auto resultTy = result.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          operation,
          "tosa to linalg conversion expects statically shaped tensors");

    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  // The loop nest has one loop per dimension of operand 0.
  unsigned nloops = t0.getRank();
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  // NOTE(review): shape[i] is read for i < rank(operand 0) on every operand,
  // which presumes all operands have at least that rank -- confirm against
  // the TOSA op verifiers.
  for (Type types : operation->getOperandTypes()) {
    auto shape = types.cast<ShapedType>().getShape();
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(nloops);
    for (unsigned i = 0; i < nloops; ++i) {
      // If the dimension is one we can broadcast the input with a constant
      // affine expression.
      if (shape[i] == 1)
        dimExprs.push_back(rewriter.getAffineConstantExpr(0));
      else
        dimExprs.push_back(rewriter.getAffineDimExpr(i));
    }
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops,
                                          /*symbolCount=*/0, dimExprs,
                                          rewriter.getContext()));
  }

  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(nloops));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operation->getOperands(), initTensors, indexingMaps,
      getNParallelLoopsAttrs(nloops),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        // A null result means the body could not be generated. Record the
        // failure and do not build a linalg.yield with a null operand.
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}
|
|
|
|
// Returns the constant initial value for a given reduction operation. The
|
|
// attribute type varies depending on the element type required.
|
|
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
|
|
PatternRewriter &rewriter) {
|
|
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
|
|
return rewriter.getFloatAttr(elementTy, 0.0);
|
|
|
|
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
|
|
return rewriter.getIntegerAttr(elementTy, 0);
|
|
|
|
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
|
|
return rewriter.getFloatAttr(elementTy, 1.0);
|
|
|
|
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
|
|
return rewriter.getIntegerAttr(elementTy, 1);
|
|
|
|
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
|
|
return rewriter.getFloatAttr(
|
|
elementTy, APFloat::getLargest(
|
|
elementTy.cast<FloatType>().getFloatSemantics(), false));
|
|
|
|
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
|
|
return rewriter.getIntegerAttr(
|
|
elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
|
|
|
|
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
|
|
return rewriter.getFloatAttr(
|
|
elementTy, APFloat::getLargest(
|
|
elementTy.cast<FloatType>().getFloatSemantics(), true));
|
|
|
|
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
|
|
return rewriter.getIntegerAttr(
|
|
elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
|
|
|
|
return {};
|
|
}
|
|
|
|
// Creates the body calculation for a reduction. The operations vary depending
|
|
// on the input type.
|
|
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
|
|
ValueRange args,
|
|
Type elementTy,
|
|
PatternRewriter &rewriter) {
|
|
Location loc = op->getLoc();
|
|
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
|
|
return rewriter.create<AddFOp>(loc, args);
|
|
}
|
|
|
|
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
|
|
return rewriter.create<AddIOp>(loc, args);
|
|
}
|
|
|
|
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
|
|
return rewriter.create<MulFOp>(loc, args);
|
|
}
|
|
|
|
if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
|
|
return rewriter.create<MulIOp>(loc, args);
|
|
}
|
|
|
|
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
|
|
auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
|
|
args[0], args[1]);
|
|
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
|
|
}
|
|
|
|
if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
|
|
auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::slt,
|
|
args[0], args[1]);
|
|
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
|
|
}
|
|
|
|
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
|
|
auto predicate = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGT,
|
|
args[0], args[1]);
|
|
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
|
|
}
|
|
|
|
if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
|
|
auto predicate = rewriter.create<mlir::CmpIOp>(loc, CmpIPredicate::sgt,
|
|
args[0], args[1]);
|
|
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
|
|
}
|
|
|
|
return {};
|
|
}
|
|
|
|
// Performs the match and rewrite for reduction operations. This includes
|
|
// declaring a correctly sized initial value, and the linalg.generic operation
|
|
// that reduces across the specified axis.
|
|
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
|
|
PatternRewriter &rewriter) {
|
|
auto loc = op->getLoc();
|
|
auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
|
|
auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
|
|
auto elementTy = resultTy.getElementType();
|
|
Value input = op->getOperand(0);
|
|
|
|
// First fill the output buffer with the init value.
|
|
auto initTensor = rewriter
|
|
.create<linalg::InitTensorOp>(loc, ArrayRef<Value>({}),
|
|
resultTy.getShape(),
|
|
resultTy.getElementType())
|
|
.result();
|
|
|
|
auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
|
|
if (!fillValueAttr)
|
|
return rewriter.notifyMatchFailure(
|
|
op, "No initial value found for reduction operation");
|
|
|
|
auto fillValue = rewriter.create<ConstantOp>(loc, fillValueAttr);
|
|
auto filledTensor =
|
|
rewriter.create<linalg::FillOp>(loc, initTensor, fillValue).result();
|
|
|
|
SmallVector<AffineExpr, 2> srcExprs;
|
|
SmallVector<AffineExpr, 2> dstExprs;
|
|
SmallVector<StringRef, 4> iteratorTypes;
|
|
for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
|
|
srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
|
|
|
|
iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
|
|
: getParallelIteratorTypeName());
|
|
if (axis != i)
|
|
dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
|
|
}
|
|
|
|
bool didEncounterError = false;
|
|
auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
|
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
|
loc, resultTy, input, filledTensor, maps, iteratorTypes,
|
|
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
|
|
auto result = createLinalgBodyCalculationForReduceOp(
|
|
op, blockArgs, elementTy, rewriter);
|
|
if (result)
|
|
didEncounterError = true;
|
|
|
|
nestedBuilder.create<linalg::YieldOp>(loc, result);
|
|
});
|
|
|
|
if (!didEncounterError)
|
|
return failure();
|
|
|
|
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
|
|
return success();
|
|
}
|
|
|
|
namespace {
|
|
|
|
template <typename SrcOp>
|
|
class PointwiseConverter : public OpRewritePattern<SrcOp> {
|
|
public:
|
|
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(SrcOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
return elementwiseMatchAndRewriteHelper(op, rewriter);
|
|
}
|
|
};
|
|
|
|
/// Converts tosa.reshape into linalg tensor reshape ops.
///
/// If the higher-rank shape can be obtained from the lower-rank one by
/// grouping consecutive dimensions, a single linalg.tensor_reshape with the
/// computed reassociation is emitted. Otherwise the operand is first
/// collapsed to a 1-D tensor and then expanded to the result shape.
/// Only statically shaped operands and results are handled.
class ReshapeOpConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef<Value> args,
                  ConversionPatternRewriter &rewriter) const final {
    tosa::ReshapeOp::Adaptor operands(args);

    ShapedType operandTy = operands.input1().getType().cast<ShapedType>();
    ShapedType resultTy = reshape.getType().cast<ShapedType>();

    if (!operandTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation. The shape with
    // the higher rank is the "expanded" one; on equal rank the result shape
    // plays that role.
    ArrayRef<int64_t> expandedShape =
        (operandTy.getRank() > resultTy.getRank() ? operandTy.getShape()
                                                  : resultTy.getShape());
    ArrayRef<int64_t> collapsedShape =
        (operandTy.getRank() > resultTy.getRank() ? resultTy.getShape()
                                                  : operandTy.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
        collapsedShape.size());

    // First scan all dimensions in the source shapes to see whether we have a
    // perfect case where consecutive dimensions in source are collapsed. For
    // such case we can just generate one single linalg.reshape.
    bool isCollapsingSource = true;
    while (currSrcDim < expandedShape.size() &&
           currDstDim < collapsedShape.size()) {
      int64_t dstSize = collapsedShape[currDstDim];
      int64_t srcSize = expandedShape[currSrcDim];
      // Fold further expanded dimensions into the current group while the
      // accumulated size is still too small. The bound is tested on the index
      // that is about to be read, so the read of the next extent stays inside
      // expandedShape even when the sizes never match (e.g. with a zero-sized
      // dimension).
      while (srcSize < dstSize && currSrcDim + 1 < expandedShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= expandedShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in collapsedShape is not 1, treat subsequent dims in
        // expandedShape which are 1 to be collapsed.
        if (currDstDim == collapsedShape.size() - 1 ||
            collapsedShape[currDstDim + 1] != 1) {
          while (currSrcDim < expandedShape.size() &&
                 expandedShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        isCollapsingSource = false;
        break;
      }
      currDstDim++;
    }
    // Every dimension of both shapes must have been consumed.
    if (currSrcDim != expandedShape.size() ||
        currDstDim != collapsedShape.size())
      isCollapsingSource = false;

    // Otherwise, we need to first reduce all source dimensions into one and
    // then expand to the destination dimensions.
    if (!isCollapsingSource) {
      auto getIdentityExprs = [&rewriter](int n) {
        SmallVector<AffineExpr, 4> exprs;
        for (int i = 0; i < n; ++i)
          exprs.push_back(rewriter.getAffineDimExpr(i));
        return exprs;
      };
      Location loc = reshape.getLoc();
      // Seed the product with an int64_t so that it is accumulated in 64 bits;
      // an `int` seed would make std::accumulate use `int` and truncate.
      int64_t totalElems = std::accumulate(
          expandedShape.begin(), expandedShape.end(), static_cast<int64_t>(1),
          std::multiplies<int64_t>());
      auto elemTy = operandTy.getElementType();
      SmallVector<linalg::ReassociationExprs, 4> collapsingMap = {
          // Use operandTy here because we need to collapse all operands
          // dimensions.
          getIdentityExprs(operandTy.getShape().size())};
      SmallVector<linalg::ReassociationExprs, 4> expandingMap = {
          // Use resultTy here because we need to expand to all result
          // dimensions.
          getIdentityExprs(resultTy.getShape().size())};

      auto collapsedTy = RankedTensorType::get({totalElems}, elemTy);
      Value collapsedOp = rewriter.create<linalg::TensorReshapeOp>(
          loc, collapsedTy, args[0], collapsingMap);
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
          reshape, resultTy, collapsedOp, expandingMap);

      return success();
    }

    rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
        reshape, resultTy, args[0], reassociationMap);

    return success();
  }
};
|
|
|
|
/// Converts tosa.transpose with a constant permutation into a linalg.generic
/// op that copies each element of the input to its permuted position in a
/// freshly created init tensor. The pattern does not apply when the
/// permutation is not a constant or the result is not statically shaped.
struct TransposeConverter : OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.perms(), m_Constant(&perms)))
      return failure();

    auto resultTy = op.getType().cast<ShapedType>();
    if (!resultTy.hasStaticShape())
      return failure();

    const auto rank = resultTy.getRank();
    Location loc = op.getLoc();

    // Loop dimension `i` indexes result dimension `i`; the input dimension it
    // indexes is the i-th entry of the permutation.
    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(rank);
    for (auto entry : llvm::enumerate(perms.getIntValues())) {
      auto inputDim = entry.value().getZExtValue();
      inputExprs[inputDim] = rewriter.getAffineDimExpr(entry.index());
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, ArrayRef<Value>({}), resultTy.getShape(),
        resultTy.getElementType());

    AffineMap inputMap = AffineMap::get(rank, /*symbolCount=*/0, inputExprs,
                                        rewriter.getContext());
    AffineMap outputMap = rewriter.getMultiDimIdentityMap(rank);
    SmallVector<AffineMap, 2> affineMaps = {inputMap, outputMap};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // The body simply forwards the input element.
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });
    return success();
  }
};
|
|
|
|
// At the codegen level any identity operations should be removed. Any cases
|
|
// where identity is load-bearing (e.g. cross device computation) should be
|
|
// handled before lowering to codegen.
|
|
template <typename SrcOp>
|
|
class IdentityNConverter : public OpRewritePattern<SrcOp> {
|
|
public:
|
|
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(SrcOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
rewriter.replaceOp(op, op.getOperation()->getOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <typename SrcOp>
|
|
class ReduceConverter : public OpRewritePattern<SrcOp> {
|
|
public:
|
|
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(SrcOp reduceOp,
|
|
PatternRewriter &rewriter) const final {
|
|
return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds every TOSA-to-Linalg pattern defined in this file to `patterns`. The
/// patterns are registered in the same order as before, grouped by kind.
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
    MLIRContext *context, OwningRewritePatternList *patterns) {
  // Elementwise operations.
  patterns->insert<
      PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>, PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>, PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>, PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>>(
      context);

  // Identity operations.
  patterns->insert<IdentityNConverter<tosa::IdentityOp>,
                   IdentityNConverter<tosa::IdentityNOp>>(context);

  // Reductions.
  patterns->insert<
      ReduceConverter<tosa::ReduceMinOp>, ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>, ReduceConverter<tosa::ReduceProdOp>>(
      context);

  // Data layout operations.
  patterns->insert<ReshapeOpConverter, TransposeConverter>(context);
}
|