forked from OSchip/llvm-project
[MLIR] Lower shape.num_elements -> shape.reduce.
Differential Revision: https://reviews.llvm.org/D81279
This commit is contained in:
parent
50f68c1e33
commit
e80617df89
|
@ -36,6 +36,10 @@ This document describes the available MLIR passes and their contracts.
|
|||
|
||||
[include "QuantPasses.md"]
|
||||
|
||||
## `shape` Dialect Passes
|
||||
|
||||
[include "ShapePasses.md"]
|
||||
|
||||
## `spv` Dialect Passes
|
||||
|
||||
[include "SPIRVPasses.md"]
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -327,7 +327,7 @@ def Shape_ReduceOp : Shape_Op<"reduce",
|
|||
|
||||
let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
|
||||
let results = (outs Variadic<AnyType>:$result);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<"OpBuilder &builder, OperationState &result, "
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(MLIRShapeTransformsIncGen)
|
||||
|
||||
add_mlir_doc(Passes -gen-pass-doc ShapePasses ./)
|
|
@ -0,0 +1,30 @@
|
|||
//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This header file defines prototypes that expose pass constructors in the
|
||||
// shape transformation library.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
|
||||
#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class Pass;
|
||||
|
||||
/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
|
||||
/// dialect to be convertible to Standard. For example, `shape.num_elements` get
|
||||
/// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
|
||||
std::unique_ptr<Pass> createShapeToShapeLowering();
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
|
|
@ -0,0 +1,19 @@
|
|||
//===-- Passes.td - ShapeOps pass definition file ----------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
|
||||
#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
|
||||
let summary = "Legalize Shape dialect to be convertible to Standard";
|
||||
let constructor = "mlir::createShapeToShapeLowering()";
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
|
|
@ -37,6 +37,7 @@
|
|||
#include "mlir/Dialect/Quant/Passes.h"
|
||||
#include "mlir/Dialect/SCF/Passes.h"
|
||||
#include "mlir/Dialect/SPIRV/Passes.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/LocationSnapshot.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
@ -94,6 +95,10 @@ inline void registerAllPasses() {
|
|||
// Standard
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
|
||||
|
||||
// Shape
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
|
|
@ -14,3 +14,5 @@ add_mlir_dialect_library(MLIRShape
|
|||
MLIRIR
|
||||
MLIRSideEffectInterfaces
|
||||
)
|
||||
|
||||
add_subdirectory(Transforms)
|
||||
|
|
|
@ -473,7 +473,7 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
|
|||
|
||||
static LogicalResult verify(ReduceOp op) {
|
||||
// Verify block arg types.
|
||||
Block &block = op.body().front();
|
||||
Block &block = op.region().front();
|
||||
|
||||
auto blockArgsCount = op.initVals().size() + 2;
|
||||
if (block.getNumArguments() != blockArgsCount)
|
||||
|
@ -529,7 +529,7 @@ static void print(OpAsmPrinter &p, ReduceOp op) {
|
|||
p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
|
||||
<< ") ";
|
||||
p.printOptionalArrowTypeList(op.getResultTypes());
|
||||
p.printRegion(op.body());
|
||||
p.printRegion(op.region());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
add_mlir_dialect_library(MLIRShapeOpsTransforms
|
||||
ShapeToShapeLowering.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
|
||||
|
||||
DEPENDS
|
||||
MLIRShapeTransformsIncGen
|
||||
)
|
||||
|
||||
target_link_libraries(MLIRShapeOpsTransforms
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRShape
|
||||
MLIRSupport
|
||||
)
|
|
@ -0,0 +1,21 @@
|
|||
//===- PassDetail.h - Shape Pass class details ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
|
||||
#define DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
|
|
@ -0,0 +1,69 @@
|
|||
//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::shape;
|
||||
|
||||
namespace {
|
||||
/// Converts `shape.num_elements` to `shape.reduce`.
|
||||
struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(NumElementsOp op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
Value init = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
|
||||
ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
|
||||
|
||||
// Generate reduce operator.
|
||||
Block *body = reduce.getBody();
|
||||
OpBuilder b = OpBuilder::atBlockEnd(body);
|
||||
Value product =
|
||||
b.create<MulOp>(loc, body->getArgument(1), body->getArgument(2));
|
||||
b.create<YieldOp>(loc, product);
|
||||
|
||||
rewriter.replaceOp(op, reduce.result());
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct ShapeToShapeLowering
|
||||
: public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
|
||||
void runOnFunction() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ShapeToShapeLowering::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<NumElementsOpConverter>(&getContext());
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<ShapeDialect>();
|
||||
target.addIllegalOp<NumElementsOp>();
|
||||
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
|
||||
return std::make_unique<ShapeToShapeLowering>();
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s --dump-input-on-failure
|
||||
|
||||
// CHECK-LABEL: func @num_elements_to_reduce(
|
||||
// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {
|
||||
func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
|
||||
%num_elements = shape.num_elements %shape
|
||||
return %num_elements : !shape.size
|
||||
}
|
||||
// CHECK: [[C1:%.*]] = shape.const_size 1
|
||||
// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) -> [[SIZE_TY]]
|
||||
// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]]
|
||||
// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
|
||||
// CHECK: shape.yield [[NEW_ACC]] : [[SIZE_TY]]
|
||||
// CHECK: }
|
||||
// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]]
|
||||
|
Loading…
Reference in New Issue