forked from OSchip/llvm-project
Add support for inlining toy call operations.
The GenericCallOp needed to have the CallOpInterface to be picked up by the inliner. This also adds a CastOp to perform shape casts that are generated during inlining. The casts generated by the inliner will be folded away after shape inference. PiperOrigin-RevId: 275150438
This commit is contained in:
parent
7053a30f4b
commit
7045471913
|
@ -21,6 +21,7 @@ add_toy_chapter(toyc-ch4
|
|||
add_dependencies(toyc-ch4 ToyCh4OpsIncGen)
|
||||
add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen)
|
||||
add_dependencies(toyc-ch4 ToyCh4CombineIncGen)
|
||||
add_dependencies(toyc-ch4 MLIRCallOpInterfacesIncGen)
|
||||
include_directories(include/)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
|
||||
|
|
|
@ -23,6 +23,11 @@
|
|||
#else
|
||||
#define TOY_OPS
|
||||
|
||||
#ifdef MLIR_CALLINTERFACES
|
||||
#else
|
||||
include "mlir/Analysis/CallInterfaces.td"
|
||||
#endif // MLIR_CALLINTERFACES
|
||||
|
||||
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||
#else
|
||||
include "toy/ShapeInferenceInterface.td"
|
||||
|
@ -111,7 +116,27 @@ def AddOp : Toy_Op<"add",
|
|||
>];
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call"> {
|
||||
def CastOp : Toy_Op<"cast",
|
||||
[DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
|
||||
SameOperandsAndResultShape]> {
|
||||
let summary = "shape cast operation";
|
||||
let description = [{
|
||||
The "cast" operation converts a tensor from one type to an equivalent type
|
||||
without changing any data elements. The source and destination types
|
||||
must both be tensor types with the same element type. If both are ranked
|
||||
then the rank should be the same and static dimensions should match. The
|
||||
operation is invalid if converting to a mismatching constant dimension.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor:$output);
|
||||
|
||||
// Set the folder bit so that we can fold redundant cast operations.
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call",
|
||||
[DeclareOpInterfaceMethods<CallOpInterface>]> {
|
||||
let summary = "generic call operation";
|
||||
let description = [{
|
||||
Generic calls represent calls to a user defined function that needs to
|
||||
|
|
|
@ -64,6 +64,17 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
|
|||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
}
|
||||
|
||||
/// Attempts to materialize a conversion for a type mismatch between a call
|
||||
/// from this dialect, and a callable region. This method should generate an
|
||||
/// operation that takes 'input' as the only operand, and produces a single
|
||||
/// result of 'resultType'. If a conversion can not be generated, nullptr
|
||||
/// should be returned.
|
||||
Operation *materializeCallConversion(OpBuilder &builder, Value *input,
|
||||
Type resultType,
|
||||
Location conversionLoc) const final {
|
||||
return builder.create<CastOp>(conversionLoc, resultType, input);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -94,7 +105,12 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
|||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||
}
|
||||
|
||||
/// Verifier for constant operation.
|
||||
/// Infer the output shape of the CastOp, this is required by the shape
|
||||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||
|
||||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||
/// in the op definition.
|
||||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
|
@ -139,6 +155,16 @@ static void buildGenericCallOp(mlir::Builder *builder,
|
|||
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
|
||||
}
|
||||
|
||||
/// Return the callee of the generic call operation, this is required by the
|
||||
/// call interface.
|
||||
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
|
||||
return getAttrOfType<SymbolRefAttr>("callee");
|
||||
}
|
||||
|
||||
/// Get the argument operands to the called function, this is required by the
|
||||
/// call interface.
|
||||
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
|
||||
|
||||
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs) {
|
||||
state.addTypes(builder->getTensorType(builder->getF64Type()));
|
||||
|
|
|
@ -80,8 +80,13 @@ public:
|
|||
|
||||
// Ask the operation to infer its output shapes.
|
||||
LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
|
||||
auto shapeOp = dyn_cast<ShapeInference>(op);
|
||||
shapeOp.inferShapes();
|
||||
if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
|
||||
shapeOp.inferShapes();
|
||||
} else {
|
||||
op->emitError("unable to infer shape of operation without shape "
|
||||
"inference interface");
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
// If the operation worklist isn't empty, this indicates a failure.
|
||||
|
|
|
@ -32,6 +32,11 @@ namespace {
|
|||
#include "ToyCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Fold simple cast operations that return the same type as the input.
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
return mlir::impl::foldCastOp(*this);
|
||||
}
|
||||
|
||||
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
|
||||
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
|
|
|
@ -122,9 +122,6 @@ int dumpMLIR() {
|
|||
// Apply any generic pass manager command line options and run the pipeline.
|
||||
applyPassManagerCLOptions(pm);
|
||||
|
||||
// Add a run of the canonicalizer to optimize the mlir module.
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
// Inline all functions into main and then delete them.
|
||||
pm.addPass(mlir::createInlinerPass());
|
||||
pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
|
||||
|
@ -132,6 +129,7 @@ int dumpMLIR() {
|
|||
// Now that there is only one function, we can infer the shapes of each of
|
||||
// the operations.
|
||||
pm.addPass(mlir::toy::createShapeInferencePass());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
if (mlir::failed(pm.run(*module)))
|
||||
return 4;
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
// RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s
|
||||
|
||||
// Check the result of inlining+shape inference on an input module.
|
||||
|
||||
func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
|
||||
%0 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
%1 = "toy.mul"(%arg0, %0) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||
"toy.return"(%1) : (tensor<*xf64>) -> ()
|
||||
}
|
||||
func @main() {
|
||||
%0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
%1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
%2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
|
||||
%3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
|
||||
%4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
%5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
"toy.print"(%5) : (tensor<*xf64>) -> ()
|
||||
"toy.return"() : () -> ()
|
||||
}
|
||||
|
||||
// CHECK-NOT: func @multiply_transpose
|
||||
// CHECK-NOT: tensor<*xf64>
|
||||
|
||||
// CHECK-LABEL: func @main() {
|
||||
// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
// CHECK: [[VAL_1:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
// CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64>
|
||||
// CHECK: [[VAL_3:%.*]] = "toy.mul"([[VAL_1]], [[VAL_2]]) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<2x2xf64>
|
||||
// CHECK: "toy.print"([[VAL_3]]) : (tensor<2x2xf64>) -> ()
|
||||
// CHECK: "toy.return"() : () -> ()
|
Loading…
Reference in New Issue