[mlir][shape] Start a pass that lowers shape constraints.

This pass converts shape.cstr_* ops to eager (side-effecting)
error-handling code. After that conversion is done, the witnesses are
trivially satisfied and are replaced with `shape.const_witness true`.

Differential Revision: https://reviews.llvm.org/D87941
This commit is contained in:
Sean Silva 2020-09-18 13:55:52 -07:00
parent c4bacc3c9b
commit 9ed1e5873c
5 changed files with 209 additions and 0 deletions

View File

@ -242,6 +242,21 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
}
def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
let summary = "Convert shape constraint operations to the standard dialect";
let description = [{
This pass eliminates shape constraints from the program, converting them to
eager (side-effecting) error handling code.
This pass is separate from the regular convert-shape-to-standard, despite
converting between the same dialects, because converting shape constraints
can happen at a different part of the program than general shape
computation lowering.
}];
let constructor = "mlir::createConvertShapeConstraintsPass()";
let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
}
//===----------------------------------------------------------------------===//
// SPIRVToLLVM
//===----------------------------------------------------------------------===//

View File

@ -13,6 +13,7 @@
namespace mlir {
class FuncOp;
class MLIRContext;
class ModuleOp;
template <typename T>
@ -24,6 +25,11 @@ void populateShapeToStandardConversionPatterns(
std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
void populateConvertShapeConstraintsConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx);
std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
} // namespace mlir
#endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_

View File

@ -1,4 +1,5 @@
add_mlir_conversion_library(MLIRShapeToStandard
ConvertShapeConstraints.cpp
ShapeToStandard.cpp
ADDITIONAL_HEADER_DIRS

View File

@ -0,0 +1,143 @@
//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "../PassDetail.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
using namespace mlir;
namespace {
class ConvertCstrBroadcastableOp
: public OpRewritePattern<shape::CstrBroadcastableOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
PatternRewriter &rewriter) const override {
if (op.getType().isa<shape::ShapeType>() ||
op.lhs().getType().isa<shape::ShapeType>() ||
op.rhs().getType().isa<shape::ShapeType>()) {
return rewriter.notifyMatchFailure(
op, "cannot convert error-propagating shapes");
}
auto loc = op.getLoc();
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
// Find smaller and greater rank and extent tensor.
Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
Value lhsSmaller =
rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
Type indexTy = rewriter.getIndexType();
Type extentTensorTy = op.lhs().getType();
auto ifOp = rewriter.create<scf::IfOp>(
loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
lhsSmaller,
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
},
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(
loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
});
Value lesserRank = ifOp.getResult(0);
Value lesserRankOperand = ifOp.getResult(1);
Value greaterRank = ifOp.getResult(2);
Value greaterRankOperand = ifOp.getResult(3);
Value rankDiff =
rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
// Generate code to compare the shapes extent by extent, and emit errors for
// non-broadcast-compatible shapes.
// Two extents are broadcast-compatible if
// 1. they are both equal, or
// 2. at least one of them is 1.
rewriter.create<scf::ForOp>(
loc, rankDiff, greaterRank, one, llvm::None,
[&](OpBuilder &b, Location loc, Value iv, ValueRange) {
Value greaterRankOperandExtent = b.create<ExtractElementOp>(
loc, greaterRankOperand, ValueRange{iv});
Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
Value lesserRankOperandExtent = b.create<ExtractElementOp>(
loc, lesserRankOperand, ValueRange{ivShifted});
Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
Value extentsAgree =
b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
lesserRankOperandExtent);
auto broadcastIsValid =
b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
lesserRankOperandExtentIsOne));
b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
b.create<scf::YieldOp>(loc);
});
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
};
} // namespace
namespace {
class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrRequireOp op,
PatternRewriter &rewriter) const override {
rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
};
} // namespace
void mlir::populateConvertShapeConstraintsConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ConvertCstrBroadcastableOp>(ctx);
patterns.insert<ConvertCstrRequireOp>(ctx);
}
namespace {
// This pass eliminates shape constraints from the program, converting them to
// eager (side-effecting) error handling code. After eager error handling code
// is emitted, witnesses are satisfied, so they are replace with
// `shape.const_witness true`.
class ConvertShapeConstraints
: public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
void runOnOperation() {
auto func = getOperation();
auto *context = &getContext();
OwningRewritePatternList patterns;
populateConvertShapeConstraintsConversionPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(func, patterns)))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::createConvertShapeConstraintsPass() {
return std::make_unique<ConvertShapeConstraints>();
}

View File

@ -0,0 +1,44 @@
// RUN: mlir-opt -convert-shape-constraints <%s | FileCheck %s
// There's not very much useful to check here other than pasting the output.
// CHECK-LABEL: func @cstr_broadcastable(
// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
// CHECK: %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
// CHECK: %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
// CHECK: scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
// CHECK: } else {
// CHECK: scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
// CHECK: }
// CHECK: %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
// CHECK: scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
// CHECK: %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
// CHECK: %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
// CHECK: %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
// CHECK: %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
// CHECK: %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
// CHECK: %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
// CHECK: %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
// CHECK: assert %[[BROADCASTISVALID]], "invalid broadcast"
// CHECK: }
// CHECK: return %[[RET]] : !shape.witness
// CHECK: }
func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
%witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
return %witness : !shape.witness
}
// CHECK-LABEL: func @cstr_require
func @cstr_require(%arg0: i1) -> !shape.witness {
// CHECK: %[[RET:.*]] = shape.const_witness true
// CHECK: assert %arg0, "msg"
// CHECK: return %[[RET]]
%witness = shape.cstr_require %arg0, "msg"
return %witness : !shape.witness
}