Add support for inlining toy call operations.

The GenericCallOp needed to implement the CallOpInterface in order to be picked up by the inliner. This change also adds a CastOp to perform shape casts that are generated during inlining; the casts generated by the inliner will be folded away after shape inference.

PiperOrigin-RevId: 275150438
This commit is contained in:
River Riddle 2019-10-16 17:32:30 -07:00 committed by A. Unique TensorFlower
parent 7053a30f4b
commit 7045471913
7 changed files with 97 additions and 7 deletions

View File

@ -21,6 +21,7 @@ add_toy_chapter(toyc-ch4
add_dependencies(toyc-ch4 ToyCh4OpsIncGen)
add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen)
add_dependencies(toyc-ch4 ToyCh4CombineIncGen)
add_dependencies(toyc-ch4 MLIRCallOpInterfacesIncGen)
include_directories(include/)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)

View File

@ -23,6 +23,11 @@
#else
#define TOY_OPS
#ifdef MLIR_CALLINTERFACES
#else
include "mlir/Analysis/CallInterfaces.td"
#endif // MLIR_CALLINTERFACES
#ifdef SHAPE_INFERENCE_INTERFACE
#else
include "toy/ShapeInferenceInterface.td"
@ -111,7 +116,27 @@ def AddOp : Toy_Op<"add",
>];
}
def GenericCallOp : Toy_Op<"generic_call"> {
// A shape cast between two f64 tensor values. Per the commit description,
// these ops are generated by the inliner to reconcile call-site operand types
// with callee argument types, and are expected to fold away once shape
// inference has run (hence hasFolder below).
def CastOp : Toy_Op<"cast",
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
SameOperandsAndResultShape]> {
let summary = "shape cast operation";
let description = [{
The "cast" operation converts a tensor from one type to an equivalent type
without changing any data elements. The source and destination types
must both be tensor types with the same element type. If both are ranked
then the rank should be the same and static dimensions should match. The
operation is invalid if converting to a mismatching constant dimension.
}];
// Exactly one f64 tensor in and one out; only the static type may differ.
let arguments = (ins F64Tensor:$input);
let results = (outs F64Tensor:$output);
// Set the folder bit so that we can fold redundant cast operations.
let hasFolder = 1;
}
def GenericCallOp : Toy_Op<"generic_call",
[DeclareOpInterfaceMethods<CallOpInterface>]> {
let summary = "generic call operation";
let description = [{
Generic calls represent calls to a user defined function that needs to

View File

@ -64,6 +64,17 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
}
/// Attempts to materialize a conversion for a type mismatch between a call
/// from this dialect, and a callable region. This method should generate an
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned.
Operation *materializeCallConversion(OpBuilder &builder, Value *input,
Type resultType,
Location conversionLoc) const final {
// Materialize the conversion as a toy.cast. CastOp is marked hasFolder, so
// redundant casts introduced here are expected to fold away after shape
// inference (see CastOp::fold).
return builder.create<CastOp>(conversionLoc, resultType, input);
}
};
//===----------------------------------------------------------------------===//
@ -94,7 +105,12 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Verifier for constant operation.
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
/// in the op definition.
static mlir::LogicalResult verify(ConstantOp op) {
// If the return type of the constant is not an unranked tensor, the shape
// must match the shape of the attribute holding the data.
@ -139,6 +155,16 @@ static void buildGenericCallOp(mlir::Builder *builder,
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}
/// CallOpInterface hook: the callee of a generic call is recorded as the
/// "callee" symbol reference attribute on the operation.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  auto calleeAttr = getAttrOfType<SymbolRefAttr>("callee");
  return calleeAttr;
}
/// Get the argument operands to the called function, this is required by the
/// call interface.
// NOTE(review): inputs() is presumably the ODS-generated operand accessor for
// this op — confirm against the generated Ops.h.inc.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs) {
state.addTypes(builder->getTensorType(builder->getF64Type()));

View File

@ -80,8 +80,13 @@ public:
// Ask the operation to infer its output shapes.
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
auto shapeOp = dyn_cast<ShapeInference>(op);
shapeOp.inferShapes();
if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
shapeOp.inferShapes();
} else {
op->emitError("unable to infer shape of operation without shape "
"inference interface");
return signalPassFailure();
}
}
// If the operation worklist isn't empty, this indicates a failure.

View File

@ -32,6 +32,11 @@ namespace {
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold simple cast operations that return the same type as the input.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
// Delegate to the shared cast folder, which replaces the cast with its
// input when the input and result types already match.
return mlir::impl::foldCastOp(*this);
}
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {

View File

@ -122,9 +122,6 @@ int dumpMLIR() {
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
// Add a run of the canonicalizer to optimize the mlir module.
pm.addPass(mlir::createCanonicalizerPass());
// Inline all functions into main and then delete them.
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
@ -132,6 +129,7 @@ int dumpMLIR() {
// Now that there is only one function, we can infer the shapes of each of
// the operations.
pm.addPass(mlir::toy::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(*module)))
return 4;

View File

@ -0,0 +1,30 @@
// RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s
// Check the result of inlining+shape inference on an input module.
// Callee with fully unranked (tensor<*xf64>) types; after inlining into
// @main, shape inference resolves the concrete shapes and the inliner's
// casts fold away (hence CHECK-NOT tensor<*xf64> below).
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
%0 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64>
%1 = "toy.mul"(%arg0, %0) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
"toy.return"(%1) : (tensor<*xf64>) -> ()
}
func @main() {
// Two 2x3 constants: %1 directly, %3 reshaped from a flat 6-element constant.
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
%1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
%2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
%3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
// Two calls to the unranked callee; only %5 is used, so after inlining and
// canonicalization the computation feeding %4 is expected to disappear.
%4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
%5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
"toy.print"(%5) : (tensor<*xf64>) -> ()
"toy.return"() : () -> ()
}
// CHECK-NOT: func @multiply_transpose
// CHECK-NOT: tensor<*xf64>
// CHECK-LABEL: func @main() {
// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
// CHECK: [[VAL_1:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
// CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64>
// CHECK: [[VAL_3:%.*]] = "toy.mul"([[VAL_1]], [[VAL_2]]) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<2x2xf64>
// CHECK: "toy.print"([[VAL_3]]) : (tensor<2x2xf64>) -> ()
// CHECK: "toy.return"() : () -> ()