forked from OSchip/llvm-project
[mlir][shape] Start a pass that lowers shape constraints.
This pass converts shape.cstr_* ops to eager (side-effecting) error-handling code. After that conversion is done, the witnesses are trivially satisfied and are replaced with `shape.const_witness true`.

Differential Revision: https://reviews.llvm.org/D87941
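For orientation, a condensed sketch of what the pass emits for a broadcastability constraint (SSA names here are illustrative, not verbatim output; the new test file at the bottom of this commit pins down the exact IR). An op such as

    %witness = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>

becomes, roughly:

    // Walk the extents of both shapes right-aligned and trap on any
    // incompatible pair.
    scf.for %iv = %rankDiff to %greaterRank step %c1 {
      // %broadcastIsValid: the two extents agree, or at least one is 1.
      assert %broadcastIsValid, "invalid broadcast"
    }
    %witness = shape.const_witness true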
This commit is contained in:
parent c4bacc3c9b
commit 9ed1e5873c

mlir/include/mlir/Conversion/Passes.td
@@ -242,6 +242,21 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
}

def ConvertShapeConstraints : Pass<"convert-shape-constraints", "FuncOp"> {
  let summary = "Convert shape constraint operations to the standard dialect";
  let description = [{
    This pass eliminates shape constraints from the program, converting them to
    eager (side-effecting) error handling code.

    This pass is separate from the regular convert-shape-to-standard, despite
    converting between the same dialects, because converting shape constraints
    can happen at a different part of the program than general shape
    computation lowering.
  }];
  let constructor = "mlir::createConvertShapeConstraintsPass()";
  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
}

//===----------------------------------------------------------------------===//
// SPIRVToLLVM
//===----------------------------------------------------------------------===//
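As a concrete illustration of the eager error handling described above, the simplest constraint, shape.cstr_require, lowers to a plain runtime assert followed by a trivially-true witness (a sketch with illustrative names; the cstr_require test at the bottom of this commit checks exactly this pattern):

    %witness = shape.cstr_require %pred, "msg"
    // becomes:
    assert %pred, "msg"
    %witness = shape.const_witness true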

mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -13,6 +13,7 @@

namespace mlir {

class FuncOp;
class MLIRContext;
class ModuleOp;
template <typename T>
@@ -24,6 +25,11 @@ void populateShapeToStandardConversionPatterns(

std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();

void populateConvertShapeConstraintsConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();

} // namespace mlir

#endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_

mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_conversion_library(MLIRShapeToStandard
  ConvertShapeConstraints.cpp
  ShapeToStandard.cpp

  ADDITIONAL_HEADER_DIRS

mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -0,0 +1,143 @@
//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
class ConvertCstrBroadcastableOp
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getType().isa<shape::ShapeType>() ||
        op.lhs().getType().isa<shape::ShapeType>() ||
        op.rhs().getType().isa<shape::ShapeType>()) {
      return rewriter.notifyMatchFailure(
          op, "cannot convert error-propagating shapes");
    }

    auto loc = op.getLoc();
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<ConstantIndexOp>(loc, 1);

    // Find smaller and greater rank and extent tensor.
    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
    Value lhsSmaller =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
    Type indexTy = rewriter.getIndexType();
    Type extentTensorTy = op.lhs().getType();
    auto ifOp = rewriter.create<scf::IfOp>(
        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
        lhsSmaller,
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(
              loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
        },
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(
              loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
        });
    Value lesserRank = ifOp.getResult(0);
    Value lesserRankOperand = ifOp.getResult(1);
    Value greaterRank = ifOp.getResult(2);
    Value greaterRankOperand = ifOp.getResult(3);

    Value rankDiff =
        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
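    // Broadcasting aligns the two shapes at their trailing dimensions, so the
    // greater-ranked operand's leading rankDiff extents have no counterpart
    // and need no check: the loop below runs over [rankDiff, greaterRank) and
    // indexes the lesser-ranked operand at iv - rankDiff.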

    // Generate code to compare the shapes extent by extent, and emit errors for
    // non-broadcast-compatible shapes.
    // Two extents are broadcast-compatible if
    // 1. they are equal, or
    // 2. at least one of them is 1.
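    // For example, the extent pairs (5, 5) and (5, 1) are compatible, while
    // (5, 3) is not and will make the generated assert fail at runtime.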

    rewriter.create<scf::ForOp>(
        loc, rankDiff, greaterRank, one, llvm::None,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
          Value greaterRankOperandExtent = b.create<ExtractElementOp>(
              loc, greaterRankOperand, ValueRange{iv});
          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
          Value lesserRankOperandExtent = b.create<ExtractElementOp>(
              loc, lesserRankOperand, ValueRange{ivShifted});

          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
          Value extentsAgree =
              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
                               lesserRankOperandExtent);
          auto broadcastIsValid =
              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
                                            lesserRankOperandExtentIsOne));
          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
          b.create<scf::YieldOp>(loc);
        });

    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace

namespace {
class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
} // namespace

void mlir::populateConvertShapeConstraintsConversionPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
  patterns.insert<ConvertCstrRequireOp>(ctx);
}

namespace {
// This pass eliminates shape constraints from the program, converting them to
// eager (side-effecting) error handling code. After eager error handling code
// is emitted, witnesses are satisfied, so they are replaced with
// `shape.const_witness true`.
class ConvertShapeConstraints
    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
  void runOnOperation() {
    auto func = getOperation();
    auto *context = &getContext();

    OwningRewritePatternList patterns;
    populateConvertShapeConstraintsConversionPatterns(patterns, context);

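    // The greedy driver applies the patterns above until it reaches a
    // fixpoint; failure to converge is surfaced as a pass failure.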
    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
      return signalPassFailure();
  }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertShapeConstraintsPass() {
  return std::make_unique<ConvertShapeConstraints>();
}

mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -0,0 +1,44 @@
// RUN: mlir-opt -convert-shape-constraints <%s | FileCheck %s

// There's not very much useful to check here other than pasting the output.
// CHECK-LABEL: func @cstr_broadcastable(
// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
// CHECK: %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
// CHECK: scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
// CHECK: } else {
// CHECK: scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
// CHECK: }
// CHECK: %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
// CHECK: scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
// CHECK: %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
// CHECK: %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
// CHECK: %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
// CHECK: %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
// CHECK: %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
// CHECK: %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
// CHECK: %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
// CHECK: assert %[[BROADCASTISVALID]], "invalid broadcast"
// CHECK: }
// CHECK: return %[[RET]] : !shape.witness
// CHECK: }
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
  %witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
  return %witness : !shape.witness
}

// CHECK-LABEL: func @cstr_require
func @cstr_require(%arg0: i1) -> !shape.witness {
  // CHECK: %[[RET:.*]] = shape.const_witness true
  // CHECK: assert %arg0, "msg"
  // CHECK: return %[[RET]]
  %witness = shape.cstr_require %arg0, "msg"
  return %witness : !shape.witness
}