[mlir][shape] Add a pattern to rewrite `shape.reduce` as `scf.for`.

Differential Revision: https://reviews.llvm.org/D81694
This commit is contained in:
Alexander Belyaev 2020-06-15 17:14:21 +02:00
parent e1741e34e0
commit 3813f24e97
7 changed files with 184 additions and 0 deletions

View File

@ -206,6 +206,15 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let constructor = "mlir::createConvertShapeToStandardPass()";
}
//===----------------------------------------------------------------------===//
// ShapeToSCF
//===----------------------------------------------------------------------===//
def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
let summary = "Convert operations from the shape dialect to the SCF dialect";
let constructor = "mlir::createConvertShapeToSCFPass()";
}
//===----------------------------------------------------------------------===//
// SPIRVToLLVM
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,27 @@
//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
#include <memory>
namespace mlir {
class MLIRContext;
class FunctionPass;
class OwningRewritePatternList;
void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx);
std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
} // namespace mlir
#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_

View File

@ -26,6 +26,7 @@
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"

View File

@ -10,6 +10,7 @@ add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(SCFToGPU)
add_subdirectory(SCFToStandard)
add_subdirectory(ShapeToSCF)
add_subdirectory(ShapeToStandard)
add_subdirectory(SPIRVToLLVM)
add_subdirectory(StandardToLLVM)

View File

@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRShapeToSCF
ShapeToSCF.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF
DEPENDS
MLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRShape
MLIRPass
MLIRSCF
MLIRTransforms
)

View File

@ -0,0 +1,99 @@
//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::shape;
namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ReduceOp op,
PatternRewriter &rewriter) const final;
};
} // namespace
LogicalResult
ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
PatternRewriter &rewriter) const {
auto loc = reduceOp.getLoc();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
Value extentTensor = rewriter.create<ToExtentTensorOp>(
loc,
RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType()),
reduceOp.shape());
Value size =
rewriter.create<DimOp>(loc, rewriter.getIndexType(), extentTensor, zero);
auto loop = rewriter.create<scf::ForOp>(
loc, zero, size, one, reduceOp.initVals(),
[&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
Value indexExtent = b.create<ExtractElementOp>(loc, extentTensor, iv);
Value sizeExtent = b.create<IndexToSizeOp>(loc, indexExtent);
SmallVector<Value, 2> mapped_values{iv, sizeExtent};
mapped_values.append(args.begin(), args.end());
BlockAndValueMapping mapping;
Block *reduceBody = reduceOp.getBody();
mapping.map(reduceBody->getArguments(), mapped_values);
for (auto &nested : reduceBody->without_terminator())
b.clone(nested, mapping);
SmallVector<Value, 2> mappedResults;
for (auto result : reduceBody->getTerminator()->getOperands())
mappedResults.push_back(mapping.lookup(result));
b.create<scf::YieldOp>(loc, mappedResults);
});
rewriter.replaceOp(reduceOp, loop.getResults());
return success();
}
namespace {
struct ConvertShapeToSCFPass
: public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
void runOnFunction() override;
};
} // namespace
void ConvertShapeToSCFPass::runOnFunction() {
MLIRContext &ctx = getContext();
OwningRewritePatternList patterns;
populateShapeToSCFConversionPatterns(patterns, &ctx);
ConversionTarget target(getContext());
target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
target.addIllegalOp<ReduceOp>();
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
}
void mlir::populateShapeToSCFConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ReduceOpConverter>(ctx);
}
std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
return std::make_unique<ConvertShapeToSCFPass>();
}

View File

@ -0,0 +1,28 @@
// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
// CHECK-LABEL: shape_reduce
// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
func @shape_reduce(%shape : !shape.shape) -> !shape.size {
%init = shape.const_size 1
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
%new_acc = shape.mul %acc, %dim
shape.yield %new_acc : !shape.size
}
return %num_elements : !shape.size
}
// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
// CHECK-NEXT: [[EXTENTS:%.*]] = "shape.to_extent_tensor"([[SHAPE]])
// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>
// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
// CHECK-SAME: step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
// CHECK-NEXT: [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
// CHECK-NEXT: [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
// CHECK-NEXT: [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
// CHECK-NEXT: }
// CHECK-NEXT: return [[RESULT]] : !shape.size