llvm-project/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

120 lines
4.6 KiB
C++

//===------- VectorToSPIRV.cpp - Vector to SPIRV lowering passes ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate SPIRV operations for Vector
// operations.
//
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
/// Lowers a `vector.broadcast` whose source is a scalar into a
/// `spv.CompositeConstruct`. The constructed composite repeats the (already
/// converted) scalar operand once per element of the result vector.
///
/// The pattern fails to match when the broadcast source is itself a vector.
/// It also fails when the result vector type is not accepted by
/// `spirv::CompositeType::isValid`.
struct VectorBroadcastConvert final
    : public SPIRVOpLowering<vector::BroadcastOp> {
  using SPIRVOpLowering<vector::BroadcastOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = broadcastOp.getVectorType();
    // Only scalar-to-vector broadcasts map directly onto CompositeConstruct.
    bool sourceIsScalar = !broadcastOp.source().getType().isa<VectorType>();
    if (!sourceIsScalar || !spirv::CompositeType::isValid(resultType))
      return failure();

    vector::BroadcastOp::Adaptor adaptor(operands);
    // One constituent per result element, all equal to the converted scalar.
    SmallVector<Value, 4> constituents(resultType.getNumElements(),
                                       adaptor.source());
    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
        broadcastOp, resultType, constituents);
    return success();
  }
};
/// Lowers a `vector.extract` that yields a scalar into a
/// `spv.CompositeExtract` on the converted source vector.
///
/// The pattern fails to match when the extracted value is itself a vector.
/// It also fails when the source vector type is not accepted by
/// `spirv::CompositeType::isValid`.
struct VectorExtractOpConvert final
    : public SPIRVOpLowering<vector::ExtractOp> {
  using SPIRVOpLowering<vector::ExtractOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(vector::ExtractOp extractOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (extractOp.getType().isa<VectorType>() ||
        !spirv::CompositeType::isValid(extractOp.getVectorType()))
      return failure();

    vector::ExtractOp::Adaptor adaptor(operands);
    // Only the leading position index is used. Presumably isValid() limits
    // the source to a 1-D vector, so one index suffices (TODO confirm).
    // IntegerAttr holds a 64-bit value while the SPIR-V builder is given an
    // int32_t index, so make the narrowing explicit instead of implicit.
    int32_t id = static_cast<int32_t>(
        extractOp.position().begin()->cast<IntegerAttr>().getInt());
    Value newExtract = rewriter.create<spirv::CompositeExtractOp>(
        extractOp.getLoc(), adaptor.vector(), id);
    rewriter.replaceOp(extractOp, newExtract);
    return success();
  }
};
/// Lowers a `vector.insert` of a scalar into a vector into a
/// `spv.CompositeInsert` on the converted operands.
///
/// The pattern fails to match when the inserted value is itself a vector.
/// It also fails when the destination vector type is not accepted by
/// `spirv::CompositeType::isValid`.
struct VectorInsertOpConvert final : public SPIRVOpLowering<vector::InsertOp> {
  using SPIRVOpLowering<vector::InsertOp>::SPIRVOpLowering;

  LogicalResult
  matchAndRewrite(vector::InsertOp insertOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (insertOp.getSourceType().isa<VectorType>() ||
        !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
      return failure();

    vector::InsertOp::Adaptor adaptor(operands);
    // Only the leading position index is used. Presumably isValid() limits
    // the destination to a 1-D vector, so one index suffices (TODO confirm).
    // IntegerAttr holds a 64-bit value while the SPIR-V builder is given an
    // int32_t index, so make the narrowing explicit instead of implicit.
    int32_t id = static_cast<int32_t>(
        insertOp.position().begin()->cast<IntegerAttr>().getInt());
    Value newInsert = rewriter.create<spirv::CompositeInsertOp>(
        insertOp.getLoc(), adaptor.source(), adaptor.dest(), id);
    rewriter.replaceOp(insertOp, newInsert);
    return success();
  }
};
} // namespace
/// Appends the Vector-to-SPIR-V lowering patterns (broadcast, extract and
/// insert) to `patterns`. Every pattern is constructed from the same
/// `context` and `typeConverter`.
void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
                                         SPIRVTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
  patterns.insert<VectorBroadcastConvert>(context, typeConverter);
  patterns.insert<VectorExtractOpConvert>(context, typeConverter);
  patterns.insert<VectorInsertOpConvert>(context, typeConverter);
}
namespace {
/// Pass operating on a ModuleOp that lowers Vector dialect ops to SPIR-V
/// using the patterns from populateVectorToSPIRVPatterns.
struct LowerVectorToSPIRVPass
    : ConvertVectorToSPIRVBase<LowerVectorToSPIRVPass> {
  void runOnOperation() override;
};
} // namespace
/// Runs the Vector-to-SPIR-V conversion over the whole module.
///
/// The SPIR-V target environment is taken from the module via
/// `spirv::lookupTargetEnvOrDefault`. ModuleOp, ModuleTerminatorOp and FuncOp
/// are explicitly marked legal on top of the SPIR-V conversion target.
/// Pass failure is signalled if `applyFullConversion` fails.
void LowerVectorToSPIRVPass::runOnOperation() {
  ModuleOp module = getOperation();
  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);

  SPIRVTypeConverter typeConverter(targetAttr);
  OwningRewritePatternList patterns;
  populateVectorToSPIRVPatterns(&getContext(), typeConverter, patterns);

  std::unique_ptr<ConversionTarget> target =
      spirv::SPIRVConversionTarget::get(targetAttr);
  target->addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp>();

  if (failed(applyFullConversion(module, *target, patterns)))
    signalPassFailure();
}
/// Creates a new instance of the Vector-to-SPIR-V lowering pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToSPIRVPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerVectorToSPIRVPass>();
  return pass;
}