forked from OSchip/llvm-project
[mlir][shape] Add a pattern to rewrite `shape.reduce` as `scf.for`.
Differential Revision: https://reviews.llvm.org/D81694
This commit is contained in:
parent
e1741e34e0
commit
3813f24e97
|
@ -206,6 +206,15 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
|
|||
let constructor = "mlir::createConvertShapeToStandardPass()";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeToSCF
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
|
||||
let summary = "Convert operations from the shape dialect to the SCF dialect";
|
||||
let constructor = "mlir::createConvertShapeToSCFPass()";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SPIRVToLLVM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
|
||||
#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class MLIRContext;
|
||||
class FunctionPass;
|
||||
class OwningRewritePatternList;
|
||||
|
||||
void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
|
||||
MLIRContext *ctx);
|
||||
|
||||
std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
|
||||
#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
|
||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
|
||||
|
|
|
@ -10,6 +10,7 @@ add_subdirectory(LinalgToSPIRV)
|
|||
add_subdirectory(LinalgToStandard)
|
||||
add_subdirectory(SCFToGPU)
|
||||
add_subdirectory(SCFToStandard)
|
||||
add_subdirectory(ShapeToSCF)
|
||||
add_subdirectory(ShapeToStandard)
|
||||
add_subdirectory(SPIRVToLLVM)
|
||||
add_subdirectory(StandardToLLVM)
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
add_mlir_conversion_library(MLIRShapeToSCF
|
||||
ShapeToSCF.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRShape
|
||||
MLIRPass
|
||||
MLIRSCF
|
||||
MLIRTransforms
|
||||
)
|
|
@ -0,0 +1,99 @@
|
|||
//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
|
||||
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::shape;
|
||||
|
||||
namespace {
|
||||
/// Converts `shape.reduce` to `scf.for`.
|
||||
struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ReduceOp op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loc = reduceOp.getLoc();
|
||||
|
||||
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
Value extentTensor = rewriter.create<ToExtentTensorOp>(
|
||||
loc,
|
||||
RankedTensorType::get({ShapedType::kDynamicSize},
|
||||
rewriter.getIndexType()),
|
||||
reduceOp.shape());
|
||||
Value size =
|
||||
rewriter.create<DimOp>(loc, rewriter.getIndexType(), extentTensor, zero);
|
||||
|
||||
auto loop = rewriter.create<scf::ForOp>(
|
||||
loc, zero, size, one, reduceOp.initVals(),
|
||||
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
|
||||
Value indexExtent = b.create<ExtractElementOp>(loc, extentTensor, iv);
|
||||
Value sizeExtent = b.create<IndexToSizeOp>(loc, indexExtent);
|
||||
|
||||
SmallVector<Value, 2> mapped_values{iv, sizeExtent};
|
||||
mapped_values.append(args.begin(), args.end());
|
||||
|
||||
BlockAndValueMapping mapping;
|
||||
Block *reduceBody = reduceOp.getBody();
|
||||
mapping.map(reduceBody->getArguments(), mapped_values);
|
||||
for (auto &nested : reduceBody->without_terminator())
|
||||
b.clone(nested, mapping);
|
||||
|
||||
SmallVector<Value, 2> mappedResults;
|
||||
for (auto result : reduceBody->getTerminator()->getOperands())
|
||||
mappedResults.push_back(mapping.lookup(result));
|
||||
b.create<scf::YieldOp>(loc, mappedResults);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(reduceOp, loop.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ConvertShapeToSCFPass
|
||||
: public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConvertShapeToSCFPass::runOnFunction() {
|
||||
MLIRContext &ctx = getContext();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
populateShapeToSCFConversionPatterns(patterns, &ctx);
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
|
||||
target.addIllegalOp<ReduceOp>();
|
||||
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
void mlir::populateShapeToSCFConversionPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<ReduceOpConverter>(ctx);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
|
||||
return std::make_unique<ConvertShapeToSCFPass>();
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: shape_reduce
|
||||
// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
|
||||
func @shape_reduce(%shape : !shape.shape) -> !shape.size {
|
||||
%init = shape.const_size 1
|
||||
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
|
||||
^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
|
||||
%new_acc = shape.mul %acc, %dim
|
||||
shape.yield %new_acc : !shape.size
|
||||
}
|
||||
return %num_elements : !shape.size
|
||||
}
|
||||
// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
|
||||
// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
|
||||
// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
|
||||
|
||||
// CHECK-NEXT: [[EXTENTS:%.*]] = "shape.to_extent_tensor"([[SHAPE]])
|
||||
// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>
|
||||
|
||||
// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
|
||||
// CHECK-SAME: step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
|
||||
// CHECK-NEXT: [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
|
||||
// CHECK-NEXT: [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
|
||||
// CHECK-NEXT: [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
|
||||
// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return [[RESULT]] : !shape.size
|
Loading…
Reference in New Issue