diff --git a/mlir/examples/toy/Ch4/CMakeLists.txt b/mlir/examples/toy/Ch4/CMakeLists.txt
index 11972e567f1e..dde70db25b61 100644
--- a/mlir/examples/toy/Ch4/CMakeLists.txt
+++ b/mlir/examples/toy/Ch4/CMakeLists.txt
@@ -1,16 +1,29 @@
+add_subdirectory(include)
+
 set(LLVM_LINK_COMPONENTS
   Support
   )
 
+set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
+mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
+add_public_tablegen_target(ToyCh4CombineIncGen)
+
 add_toy_chapter(toyc-ch4
   toyc.cpp
   parser/AST.cpp
   mlir/MLIRGen.cpp
-  mlir/ToyDialect.cpp
+  mlir/Dialect.cpp
+  mlir/DeadFunctionEliminationPass.cpp
   mlir/ShapeInferencePass.cpp
   mlir/ToyCombine.cpp
   )
+
+add_dependencies(toyc-ch4 ToyCh4OpsIncGen)
+add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen)
+add_dependencies(toyc-ch4 ToyCh4CombineIncGen)
 include_directories(include/)
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
 target_link_libraries(toyc-ch4
   PRIVATE
     MLIRAnalysis
diff --git a/mlir/examples/toy/Ch4/include/CMakeLists.txt b/mlir/examples/toy/Ch4/include/CMakeLists.txt
new file mode 100644
index 000000000000..37c89d0bae96
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(toy)
diff --git a/mlir/examples/toy/Ch4/include/toy/AST.h b/mlir/examples/toy/Ch4/include/toy/AST.h
index 456a32309c40..2ad3392c11ac 100644
--- a/mlir/examples/toy/Ch4/include/toy/AST.h
+++ b/mlir/examples/toy/Ch4/include/toy/AST.h
@@ -33,10 +33,9 @@
 
 namespace toy {
 
-/// A variable
+/// A variable type with shape information.
 struct VarType {
-  enum { TY_FLOAT, TY_INT } elt_ty;
-  std::vector<int64_t> shape;
+  std::vector<int64_t> shape;
 };
 
 /// Base class for all expression nodes.
@@ -50,9 +49,7 @@ public:
     Expr_Var,
     Expr_BinOp,
     Expr_Call,
-    Expr_Print, // builtin
-    Expr_If,
-    Expr_For,
+    Expr_Print,
   };
 
   ExprAST(ExprASTKind kind, Location location)
@@ -85,7 +82,7 @@ public:
   static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
 };
 
-///
+/// Expression class for a literal value.
 class LiteralExprAST : public ExprAST {
   std::vector<std::unique_ptr<ExprAST>> values;
   std::vector<int64_t> dims;
@@ -116,7 +113,7 @@ public:
   static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
 };
 
-///
+/// Expression class for defining a variable.
 class VarDeclExprAST : public ExprAST {
   std::string name;
   VarType type;
@@ -136,7 +133,7 @@ public:
   static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
 };
 
-///
+/// Expression class for a return operator.
 class ReturnExprAST : public ExprAST {
   llvm::Optional<std::unique_ptr<ExprAST>> expr;
diff --git a/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt
new file mode 100644
index 000000000000..798d0df1d8d6
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/toy/CMakeLists.txt
@@ -0,0 +1,9 @@
+set(LLVM_TARGET_DEFINITIONS Ops.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
+add_public_tablegen_target(ToyCh4OpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
+mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(ToyCh4ShapeInferenceInterfaceIncGen)
diff --git a/mlir/examples/toy/Ch4/include/toy/Dialect.h b/mlir/examples/toy/Ch4/include/toy/Dialect.h
index b0838870b5a5..da61191c6c0f 100644
--- a/mlir/examples/toy/Ch4/include/toy/Dialect.h
+++ b/mlir/examples/toy/Ch4/include/toy/Dialect.h
@@ -16,7 +16,7 @@
 // =============================================================================
 //
 // This file implements the IR Dialect for the Toy language.
-// See g3doc/Tutorials/Toy/Ch-3.md for more information.
+// See g3doc/Tutorials/Toy/Ch-2.md for more information.
 //
 //===----------------------------------------------------------------------===//
 
@@ -25,325 +25,30 @@
 
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/TypeSupport.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/StandardTypes.h"
 
 namespace mlir {
-class Builder;
-}
-
 namespace toy {
 
 /// This is the definition of the Toy dialect. A dialect inherits from
-/// mlir::Dialect and register custom operations and types (in its constructor).
-/// It can also overriding general behavior of dialects exposed as virtual
-/// method, for example regarding verification and parsing/printing.
+/// mlir::Dialect and registers custom attributes, operations, and types (in its
+/// constructor). It can also override some general behavior exposed via virtual
+/// methods.
 class ToyDialect : public mlir::Dialect {
 public:
   explicit ToyDialect(mlir::MLIRContext *ctx);
 
-  /// Parse a type registered to this dialect. Overriding this method is
-  /// required for dialects that have custom types.
-  /// Technically this is only needed to be able to round-trip to textual IR.
-  mlir::Type parseType(llvm::StringRef tyData,
-                       mlir::Location loc) const override;
-
-  /// Print a type registered to this dialect. Overriding this method is
-  /// only required for dialects that have custom types.
-  /// Technically this is only needed to be able to round-trip to textual IR.
-  void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+  /// Provide a utility accessor to the dialect namespace. This is used by
+  /// several utilities for casting between dialects.
+  static llvm::StringRef getDialectNamespace() { return "toy"; }
 };
 
-////////////////////////////////////////////////////////////////////////////////
-/////////////////////// Custom Types for the Dialect ///////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-namespace detail {
-struct ToyArrayTypeStorage;
-}
-
-/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
-enum ToyTypeKind {
-  // The enum starts at the range reserved for this dialect.
-  TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
-  TOY_ARRAY,
-};
-
-/// Type for Toy arrays.
-/// In MLIR Types are reference to immutable and uniqued objects owned by the
-/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
-/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
-/// provides the public facade API to interact with the type.
-class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
-                                                 detail::ToyArrayTypeStorage> {
-public:
-  using Base::Base;
-
-  /// Returns the dimensions for this array, or and empty range for a generic
-  /// array.
-  llvm::ArrayRef<int64_t> getShape();
-
-  /// Predicate to test if this array is generic (shape haven't been inferred
-  /// yet).
-  bool isGeneric() { return getShape().empty(); }
-
-  /// Return the rank of this array (0 if it is generic).
-  int getRank() { return getShape().size(); }
-
-  /// Return the type of individual elements in the array.
-  mlir::Type getElementType();
-
-  /// Get the unique instance of this Type from the context.
-  /// A ToyArrayType is only defined by the shape of the array.
-  static ToyArrayType get(mlir::MLIRContext *context,
-                          llvm::ArrayRef<int64_t> shape = {});
-
-  /// Support method to enable LLVM-style RTTI type casting.
-  static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
-};
-
-////////////////////////////////////////////////////////////////////////////////
-//////////////////// Custom Operations for the Dialect /////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-/// Constant operation turns a literal into an SSA value. The data is attached
-/// to the operation as an attribute. For example:
-///
-///   %0 = "toy.constant"()
-///       {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
-///     : () -> !toy.array<2, 3>
-///
-/// An operation inherits from `class Op` and specifies optional traits. Here we
-/// indicate that `toy.constant` does not have any operands and returns a single
-/// result. The traits provide some utilities methods for the operation, for
-/// instance we will be able to use `getResult()`, but `getOperand()` won't be
-/// available.
-class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
-                                   mlir::OpTrait::OneResult,
-                                   mlir::OpTrait::HasNoSideEffect> {
-public:
-  /// This is the name used by MLIR to match an operation to this class during
-  /// parsing.
-  static llvm::StringRef getOperationName() { return "toy.constant"; }
-
-  /// The operation can have extra verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<ConstantOp>(...)
-  /// This method populates the `state` that MLIR uses to create operations.
-  /// The `toy.constant` operation does not have arguments but attaches a
-  /// constant array as an attribute and returns it as an SSA value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    llvm::ArrayRef<int64_t> shape,
-                    mlir::DenseElementsAttr value);
-
-  /// Similar to the one above, but takes a single float and returns a
-  /// !toy.array<1>.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::FloatAttr value);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Generic calls represent calls to a user defined function that needs to
-/// be specialized for the shape of its arguments. The callee name is attached
-/// as a literal string as an attribute. The arguments list must match the
-/// arguments expected by the callee. For example:
-///
-///   %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
-///     : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
-///
-/// This is only valid if a function named "my_func" exists and takes two
-/// arguments.
-class GenericCallOp
-    : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
-                      mlir::OpTrait::OneResult> {
-public:
-  /// MLIR will use this to register the operation with the parser/printer.
-  static llvm::StringRef getOperationName() { return "toy.generic_call"; }
-
-  /// Operations can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to the builder to allow:
-  ///   mlir::Builder::create<GenericCallOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.generic_call` operation accepts a callee name and a list of
-  /// arguments for the call.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    llvm::StringRef callee,
-                    llvm::ArrayRef<mlir::Value *> arguments);
-
-  /// Return the name of the callee.
-  llvm::StringRef getCalleeName();
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Return operations terminate blocks (and functions as well). They take a
-/// single argument and the type must match the function return type.
-class ReturnOp
-    : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
-                      mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.return"; }
-
-  /// Operations can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<ReturnOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.return` operation accepts an optional single array as an argument
-  /// and does not have any returned value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value = nullptr);
-
-  /// Return true if there is a returned value.
-  bool hasOperand() { return 0 != getNumOperands(); }
-
-  /// Helper to return the optional operand. Caller must check if the operand
-  /// is present before calling this.
-  mlir::Value *getOperand() { return getOperation()->getOperand(0); }
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// The print builtin takes a single array argument and does not return any.
-class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
-                                mlir::OpTrait::ZeroResult> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.print"; }
-
-  /// Operations can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<PrintOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.print` operation accepts a single array as argument and does
-  /// not have any returned value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
-                                    mlir::OpTrait::OneResult,
-                                    mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.transpose"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<TransposeOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.transpose` operation accepts a single array as argument and
-  /// returns the transposed array as its only result.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value);
-
-  // Register our patterns for rewrite by the Canonicalization framework.
-  static void
-  getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
-                              mlir::MLIRContext *context);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Reshape operation is transforming its input array into a new array with the
-/// same number of elements but different shapes. For example:
-///
-///   %0 = "toy.reshape"(%arg1) : (!toy.array<10>) -> !toy.array<5, 2>
-///
-class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
-                                  mlir::OpTrait::OneResult,
-                                  mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.reshape"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<ReshapeOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.reshape` operation accepts a single array as argument and
-  /// returns the array with the specified reshapedType as its only result.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value, ToyArrayType reshapedType);
-
-  // Register our patterns for rewrite by the Canonicalization framework.
-  static void
-  getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
-                              mlir::MLIRContext *context);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Binary operation implementing a multiplication. For two-dimensional array
-/// a matrix multiplication is implemented, while for one dimensional array a
-/// dot product is performed.
-class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
-                              mlir::OpTrait::OneResult,
-                              mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.mul"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<MulOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.mul` operation accepts two operands as argument and returns
-  /// a single value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *lhs, mlir::Value *rhs);
-
-  /// Convenience accessor for LHS of the expression.
-  mlir::Value *getLHS() { return getOperand(0); }
-
-  /// Convenience accessor for RHS of the expression.
-  mlir::Value *getRHS() { return getOperand(1); }
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Element wise addition of two arrays. The shape must match.
-class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
-                              mlir::OpTrait::OneResult,
-                              mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.add"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<AddOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.mul` operation accepts two operands as argument and returns
-  /// a single value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *lhs, mlir::Value *rhs);
-
-  /// Convenience accessor for LHS of the expression.
-  mlir::Value *getLHS() { return getOperand(0); }
-
-  /// Convenience accessor for RHS of the expression.
-  mlir::Value *getRHS() { return getOperand(1); }
-
-  /// Inherit constructor.
-  using Op::Op;
-};
+/// Include the auto-generated header file containing the declarations of the
+/// toy operations.
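/// (Illustrative note, not part of this change: "toy/Ops.h.inc" is produced
/// from Ops.td by mlir-tblgen's -gen-op-decls mode, wired up through the
/// mlir_tablegen() rules in include/toy/CMakeLists.txt earlier in this patch.)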
+#define GET_OP_CLASSES
+#include "toy/Ops.h.inc"
 
 } // end namespace toy
+} // end namespace mlir
 
 #endif // MLIR_TUTORIAL_TOY_DIALECT_H_
diff --git a/mlir/examples/toy/Ch4/include/toy/Lexer.h b/mlir/examples/toy/Ch4/include/toy/Lexer.h
index d73adb9706b7..21f92614912e 100644
--- a/mlir/examples/toy/Ch4/include/toy/Lexer.h
+++ b/mlir/examples/toy/Ch4/include/toy/Lexer.h
@@ -31,7 +31,7 @@ namespace toy {
 
 /// Structure definition a location in a file.
 struct Location {
-  std::shared_ptr<std::string> file; ///< filename
+  std::shared_ptr<std::string> file; ///< filename.
   int line;                          ///< line number.
   int col;                           ///< column number.
 };
diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
new file mode 100644
index 000000000000..f0140d70f9bd
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -0,0 +1,285 @@
+//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operations of the Toy dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef TOY_OPS
+#else
+#define TOY_OPS
+
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+include "toy/ShapeInferenceInterface.td"
+#endif // SHAPE_INFERENCE_INTERFACE
+
+// Provide a definition of the 'toy' dialect in the ODS framework so that we
+// can define our operations.
+def Toy_Dialect : Dialect {
+  let name = "toy";
+  let cppNamespace = "toy";
+}
+
+// Base class for toy dialect operations. This operation inherits from the base
+// `Op` class in OpBase.td, and provides:
+//   * The parent dialect of the operation.
+//   * The mnemonic for the operation, or the name without the dialect prefix.
+//   * A list of traits for the operation.
+class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Toy_Dialect, mnemonic, traits>;
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+// We define a toy operation by inheriting from our base 'Toy_Op' class above.
+// Here we provide the mnemonic and a list of traits for the operation. The
+// constant operation is marked as 'NoSideEffect' as it is a pure operation
+// and may be removed if dead.
+def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
+  // Provide a summary and description for this operation. This can be used to
+  // auto-generate documentation of the operations within our dialect.
+  let summary = "constant";
+  let description = [{
+    Constant operation turns a literal into an SSA value. The data is attached
+    to the operation as an attribute. For example:
+
+    ```mlir
+      %0 = "toy.constant"()
+         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
+        : () -> tensor<2x3xf64>
+    ```
+  }];
+
+  // The constant operation takes an attribute as the only input.
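  // (Illustrative note, not part of this change: from the declaration below,
  // ODS generates a typed accessor on the C++ op class roughly of the form
  //   mlir::DenseElementsAttr ConstantOp::value();
  // so passes can read the constant data without a string attribute lookup.)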
+  let arguments = (ins F64ElementsAttr:$value);
+
+  // The constant operation returns a single value of TensorType.
+  let results = (outs F64Tensor);
+
+  // Add custom build methods for the constant operation. These methods
+  // populate the `state` that MLIR uses to create operations, i.e. these are
+  // used when using `builder.create<ConstantOp>(...)`.
+  let builders = [
+    // Build a constant with a given constant tensor value.
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "DenseElementsAttr value", [{
+      build(builder, result, value.getType(), value);
+    }]>,
+
+    // Build a constant with a given constant floating-point value.
+    OpBuilder<"Builder *builder, OperationState &result, double value", [{
+      buildConstantOp(builder, result, value);
+    }]>
+  ];
+
+  // Invoke a static verify method to verify this constant operation.
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def AddOp : Toy_Op<"add", [NoSideEffect]> {
+  let summary = "element-wise addition operation";
+  let description = [{
+    The "add" operation performs element-wise addition between two tensors.
+    The shapes of the tensor operands are expected to match.
+  }];
+
+  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+  let results = (outs F64Tensor);
+
+  // Allow building an AddOp from the two input operands.
+  let builders = [
+    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+      buildAddOp(b, result, lhs, rhs);
+    }]
+  >];
+  let extraClassDeclaration = [{
+    void inferShapes() {
+      getResult()->setType(getOperand(0)->getType());
+      return;
+    }
+  }];
+}
+
+def GenericCallOp : Toy_Op<"generic_call"> {
+  let summary = "generic call operation";
+  let description = [{
+    Generic calls represent calls to a user defined function that needs to
+    be specialized for the shape of its arguments. The callee name is attached
+    as a symbol reference via an attribute. The arguments list must match the
+    arguments expected by the callee. For example:
+
+    ```mlir
+     %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
+           : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+    ```
+
+    This is only valid if a function named "my_func" exists and takes two
+    arguments.
+  }];
+
+  // The generic call operation takes a symbol reference attribute as the
+  // callee, and inputs for the call.
+  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+
+  // The generic call operation returns a single value of TensorType.
+  let results = (outs F64Tensor);
+
+  // Add custom build methods for the generic call operation.
+  let builders = [
+    // Build a generic call with the given callee name and argument values.
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "StringRef callee, ArrayRef<Value *> arguments", [{
+      buildGenericCallOp(builder, result, callee, arguments);
+    }]>
+  ];
+}
+
+def MulOp : Toy_Op<"mul", [NoSideEffect]> {
+  let summary = "element-wise multiplication operation";
+  let description = [{
+    The "mul" operation performs element-wise multiplication between two
+    tensors. The shapes of the tensor operands are expected to match.
+  }];
+
+  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
+  let results = (outs F64Tensor);
+
+  // Allow building a MulOp from the two input operands.
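  // (Illustrative usage, not part of this change: given f64 tensor values
  // `lhs` and `rhs`, the custom builder below enables C++ code such as
  //   mlir::Value *product = builder.create<MulOp>(location, lhs, rhs);
  // from the MLIRGen visitor.)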
+  let builders = [
+    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
+      buildMulOp(b, result, lhs, rhs);
+    }]
+  >];
+  let extraClassDeclaration = [{
+    void inferShapes() {
+      auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
+      auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
+      auto lhsRank = lhs.getShape().size();
+      auto rhsRank = rhs.getShape().size();
+      if (lhsRank != rhsRank) {
+        return;
+      }
+      SmallVector<int64_t, 2> dims;
+      if (lhsRank == 1) {
+        // dot product, result shape is <1>
+        dims.push_back(1);
+      } else {
+        if (lhsRank != 2) {
+          return;
+        }
+        dims.push_back(lhs.getShape()[0]);
+        dims.push_back(rhs.getShape()[1]);
+      }
+      getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
+      return;
+    }
+  }];
+}
+
+def PrintOp : Toy_Op<"print"> {
+  let summary = "print operation";
+  let description = [{
+    The "print" builtin operation prints a given input tensor, and produces
+    no results.
+  }];
+
+  // The print operation takes an input tensor to print.
+  let arguments = (ins F64Tensor:$input);
+}
+
+def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
+  let summary = "tensor reshape operation";
+  let description = [{
+    Reshape operation is transforming its input tensor into a new tensor with
+    the same number of elements but different shapes. For example:
+
+    ```mlir
+      %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
+    ```
+  }];
+
+  let arguments = (ins F64Tensor:$input);
+  let hasCanonicalizer = 1;
+
+  // We expect that the reshape operation returns a statically shaped tensor.
+  let results = (outs StaticShapeTensorOf<[F64]>);
+}
+
+def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
+  let summary = "return operation";
+  let description = [{
+    The "return" operation represents a return operation within a function.
+    The operation takes an optional tensor operand and produces no results.
+    The operand type must match the signature of the function that contains
+    the operation. For example:
+
+    ```mlir
+      func @foo() -> tensor<2xf64> {
+        ...
+        toy.return %0 : tensor<2xf64>
+      }
+    ```
+  }];
+
+  // The return operation takes an optional input operand to return. This
+  // value must match the return type of the enclosing function.
+  let arguments = (ins Variadic<F64Tensor>:$input);
+
+  // Allow building a ReturnOp with no return operand.
+  let builders = [OpBuilder<
+    "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
+  >];
+
+  // Provide extra utility definitions on the c++ operation class definition.
+  let extraClassDeclaration = [{
+    bool hasOperand() { return getNumOperands() != 0; }
+  }];
+
+  // Invoke a static verify method to verify this return operation.
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
+  let summary = "transpose operation";
+
+  let arguments = (ins F64Tensor:$input);
+  let results = (outs F64Tensor);
+  let hasCanonicalizer = 1;
+
+  // Allow building a TransposeOp from the input operand.
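  // (Illustration, not part of this change: in generic form this operation
  // round-trips as
  //   %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
  // once the inferShapes() hook below has run.)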
+  let builders = [
+    OpBuilder<"Builder *b, OperationState &result, Value *input", [{
+      buildTransposeOp(b, result, input);
+    }]
+  >];
+  let extraClassDeclaration = [{
+    void inferShapes() {
+      SmallVector<int64_t, 2> dims;
+      auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+      dims.insert(dims.end(), arrayTy.getShape().begin(),
+                  arrayTy.getShape().end());
+      if (dims.size() == 2)
+        std::swap(dims[0], dims[1]);
+      getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+      return;
+    }
+  }];
+}
+
+#endif // TOY_OPS
diff --git a/mlir/examples/toy/Ch4/include/toy/Passes.h b/mlir/examples/toy/Ch4/include/toy/Passes.h
index 93cf0d5ba155..8c8365d6882c 100644
--- a/mlir/examples/toy/Ch4/include/toy/Passes.h
+++ b/mlir/examples/toy/Ch4/include/toy/Passes.h
@@ -26,10 +26,11 @@
 
 namespace mlir {
 class Pass;
-} // namespace mlir
 
 namespace toy {
-std::unique_ptr<mlir::Pass> createShapeInferencePass();
-} // namespace toy
+std::unique_ptr<mlir::Pass> createShapeInferencePass();
+std::unique_ptr<mlir::Pass> createDeadFunctionEliminationPass();
+} // end namespace toy
+} // end namespace mlir
 
 #endif // MLIR_TUTORIAL_TOY_PASSES_H
diff --git a/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
new file mode 100644
index 000000000000..2040cc44fdf4
--- /dev/null
+++ b/mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
@@ -0,0 +1,38 @@
+//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines the operation interface for shape inference.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef SHAPE_INFERENCE_INTERFACE
+#else
+#define SHAPE_INFERENCE_INTERFACE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let methods = [
+    InterfaceMethod<"Infer output shape for the current operation.",
+                    "void", "inferShapes", (ins), [{}]>
+  ];
+}
+
+#endif // SHAPE_INFERENCE_INTERFACE
diff --git a/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp
new file mode 100644
index 000000000000..e7e64ce5b3d4
--- /dev/null
+++ b/mlir/examples/toy/Ch4/mlir/DeadFunctionEliminationPass.cpp
@@ -0,0 +1,61 @@
+//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a Module level pass performing dead function
+// elimination. This is required as a post-processing step after function
+// inlining.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "toy/Passes.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <memory>
+
+namespace {
+class DeadFunctionEliminationPass
+    : public mlir::ModulePass<DeadFunctionEliminationPass> {
+public:
+  void runOnModule() override {
+    std::string str = "main";
+    auto module = getModule();
+    for (auto &f : module) {
+      // Eliminate dead functions that are not "main".
+      if (str.find(f.getName().getStringRef()) == std::string::npos)
+        f.erase();
+    }
+  }
+};
+} // namespace
+
+/// Create a pass that eliminates inlined functions in toy.
+std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
+  return std::make_unique<DeadFunctionEliminationPass>();
+}
diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
new file mode 100644
index 000000000000..63eee4eefb8b
--- /dev/null
+++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp
@@ -0,0 +1,190 @@
+//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the dialect for the Toy IR: custom type parsing and
+// operation verification.
+//
+//===----------------------------------------------------------------------===//
+
+#include "toy/Dialect.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace mlir::toy;
+
+//===----------------------------------------------------------------------===//
+// ToyInlinerInterface
+//===----------------------------------------------------------------------===//
+
+/// This class defines the interface for handling inlining with Toy
+/// operations.
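/// (Illustrative note, not part of this change: the inliner consults this
/// interface when deciding whether to fold "toy.generic_call" bodies into
/// their callers; it is registered on the dialect via addInterfaces() in the
/// ToyDialect constructor below.)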
+struct ToyInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// All operations within toy can be inlined.
+  bool isLegalToInline(Operation *, Region *,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//
+
+  /// Handle the given inlined terminator (toy.return) by replacing it with a
+  /// new operation as necessary.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value *> valuesToRepl) const final {
+    // Only "toy.return" needs to be handled here.
+    auto returnOp = cast<ReturnOp>(op);
+
+    // Replace the values directly with the return operands.
+    assert(returnOp.getNumOperands() == valuesToRepl.size());
+    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ToyDialect
+//===----------------------------------------------------------------------===//
+
+/// Dialect creation, the instance will be owned by the context. This is the
+/// point of registration of custom types and operations for the dialect.
+ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
+  addOperations<
+#define GET_OP_LIST
+#include "toy/Ops.cpp.inc"
+      >();
+  addInterfaces<ToyInlinerInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// Toy Operations
+//===----------------------------------------------------------------------===//
+
+/// Build a constant operation.
+/// The builder is passed as an argument, so is the state that this method is
+/// expected to fill in order to build the operation.
+static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
+                            double value) {
+  auto dataType = builder->getTensorType({}, builder->getF64Type());
+  auto dataAttribute = DenseElementsAttr::get(dataType, value);
+  ConstantOp::build(builder, state, dataType, dataAttribute);
+}
+
+/// Verifier for constant operation.
+static mlir::LogicalResult verify(ConstantOp op) {
+  // If the return type of the constant is not an unranked tensor, the shape
+  // must match the shape of the attribute holding the data.
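  // (Illustration, not part of this change: the checks below would reject
  //   %0 = "toy.constant"() { value = dense<[[1.0, 2.0, 3.0]]> : tensor<1x3xf64> }
  //        : () -> tensor<2x3xf64>
  // since the 1x3 attribute data does not match the 2x3 result shape.)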
+  auto resultType =
+      op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
+  if (!resultType)
+    return success();
+
+  auto attrType = op.value().getType().cast<mlir::TensorType>();
+  if (attrType.getRank() != resultType.getRank()) {
+    return op.emitOpError(
+               "return type must match the one of the attached value "
+               "attribute: ")
+           << attrType.getRank() << " != " << resultType.getRank();
+  }
+  for (int dim = 0; dim < attrType.getRank(); ++dim) {
+    if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
+      return op.emitOpError(
+                 "return type shape mismatches its attribute at dimension ")
+             << dim << ": " << attrType.getShape()[dim]
+             << " != " << resultType.getShape()[dim];
+    }
+  }
+  return mlir::success();
+}
+
+static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
+                       mlir::Value *lhs, mlir::Value *rhs) {
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+static void buildGenericCallOp(mlir::Builder *builder,
+                               mlir::OperationState &state, StringRef callee,
+                               ArrayRef<mlir::Value *> arguments) {
+  // Generic call always returns an unranked Tensor initially.
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands(arguments);
+  state.addAttribute("callee", builder->getSymbolRefAttr(callee));
+}
+
+static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
+                       mlir::Value *lhs, mlir::Value *rhs) {
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands({lhs, rhs});
+}
+
+static mlir::LogicalResult verify(ReturnOp op) {
+  // We know that the parent operation is a function, because of the
+  // 'HasParent' trait attached to the operation definition.
+  auto function = cast<FuncOp>(op.getParentOp());
+
+  /// ReturnOps can only have a single optional operand.
+  if (op.getNumOperands() > 1)
+    return op.emitOpError() << "expects at most 1 return operand";
+
+  // The operand number and types must match the function signature.
+  const auto &results = function.getType().getResults();
+  if (op.getNumOperands() != results.size())
+    return op.emitOpError()
+           << "does not return the same number of values ("
+           << op.getNumOperands() << ") as the enclosing function ("
+           << results.size() << ")";
+
+  // If the operation does not have an input, we are done.
+  if (!op.hasOperand())
+    return mlir::success();
+
+  auto inputType = *op.operand_type_begin();
+  auto resultType = results.front();
+
+  // Check that the result type of the function matches the operand type.
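  // (Illustration, not part of this change: an unranked tensor<*xf64> on
  // either side is accepted below, since shape inference may not have run yet
  // when this verifier fires.)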
+  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
+      resultType.isa<mlir::UnrankedTensorType>())
+    return mlir::success();
+
+  return op.emitError() << "type of return operand ("
+                        << *op.operand_type_begin()
+                        << ") doesn't match function result type ("
+                        << results.front() << ")";
+}
+
+static void buildTransposeOp(mlir::Builder *builder,
+                             mlir::OperationState &state, mlir::Value *value) {
+  state.addTypes(builder->getTensorType(builder->getF64Type()));
+  state.addOperands(value);
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "toy/Ops.cpp.inc"
diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
index c66335a70d5e..ace52aff2bf0 100644
--- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
+++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
@@ -25,30 +25,30 @@
 #include "toy/Dialect.h"
 
 #include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
-#include "mlir/IR/Location.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopedHashTable.h"
 #include "llvm/Support/raw_ostream.h"
 #include <numeric>
 
+using namespace mlir::toy;
 using namespace toy;
+
+using llvm::ArrayRef;
 using llvm::cast;
 using llvm::dyn_cast;
 using llvm::isa;
+using llvm::makeArrayRef;
 using llvm::ScopedHashTableScope;
 using llvm::SmallVector;
 using llvm::StringRef;
 using llvm::Twine;
-using std::make_unique;
 
 namespace {
 
@@ -57,56 +57,43 @@ namespace {
 /// This will emit operations that are specific to the Toy language, preserving
 /// the semantics of the language and (hopefully) allow to perform accurate
 /// analysis and transformation based on these high level semantics.
-///
-/// At this point we take advantage of the "raw" MLIR APIs to create operations
-/// that haven't been registered in any way with MLIR. These operations are
-/// unknown to MLIR, custom passes could operate by string-matching the name of
-/// these operations, but no other type checking or semantic is associated with
-/// them natively by MLIR.
 class MLIRGenImpl {
 public:
-  MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
+  MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
 
   /// Public API: convert the AST for a Toy module (source file) to an MLIR
-  /// Module.
-  mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
+  /// Module operation.
+  mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
     // We create an empty MLIR module and codegen functions one at a time and
     // add them to the module.
-    theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
+    theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
 
     for (FunctionAST &F : moduleAST) {
       auto func = mlirGen(F);
      if (!func)
         return nullptr;
-      theModule->push_back(func);
+      theModule.push_back(func);
     }
 
-    // FIXME: (in the next chapter...) without registering a dialect in MLIR,
-    // this won't do much, but it should at least check some structural
-    // properties.
-    if (failed(mlir::verify(*theModule))) {
-      emitError(mlir::UnknownLoc::get(&context), "module verification error");
+    // Verify the module after we have finished constructing it; this will
+    // check the structural properties of the IR and invoke any specific
+    // verifiers we have on the Toy operations.
+    if (failed(mlir::verify(theModule))) {
+      theModule.emitError("module verification error");
       return nullptr;
     }
 
-    return std::move(theModule);
+    return theModule;
   }
 
 private:
-  /// In MLIR (like in LLVM) a "context" object holds the memory allocation and
-  /// the ownership of many internal structure of the IR and provide a level
-  /// of "uniquing" across multiple modules (types for instance).
-  mlir::MLIRContext &context;
+  /// A "module" matches a Toy source file, containing a list of functions.
+  mlir::ModuleOp theModule;
 
-  /// A "module" matches a source file: it contains a list of functions.
-  mlir::OwningModuleRef theModule;
-
-  /// The builder is a helper class to create IR inside a function. It is
-  /// re-initialized every time we enter a function and kept around as a
-  /// convenience for emitting individual operations.
-  /// The builder is stateful, in particular it keeps an "insertion point":
-  /// this is where the next operations will be introduced.
-  std::unique_ptr<mlir::OpBuilder> builder;
+  /// The builder is a helper class to create IR inside a function. The builder
+  /// is stateful, in particular it keeps an "insertion point": this is where
+  /// the next operations will be introduced.
+  mlir::OpBuilder builder;
 
   /// The symbol table maps a variable name to a value in the current scope.
   /// Entering a function creates a new scope, and the function arguments are
@@ -116,37 +103,35 @@ private:
 
   /// Helper conversion for a Toy AST location to an MLIR location.
   mlir::Location loc(Location loc) {
-    return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
-                                     loc.line, loc.col, &context);
+    return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
+                                     loc.col);
   }
 
-  /// Declare a variable in the current scope, return true if the variable
+  /// Declare a variable in the current scope, return success if the variable
   /// wasn't declared yet.
-  bool declare(llvm::StringRef var, mlir::Value *value) {
-    if (symbolTable.count(var)) {
-      return false;
-    }
+  mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
+    if (symbolTable.count(var))
+      return mlir::failure();
     symbolTable.insert(var, value);
-    return true;
+    return mlir::success();
  }
 
   /// Create the prototype for an MLIR function with as many arguments as the
   /// provided Toy AST prototype.
   mlir::FuncOp mlirGen(PrototypeAST &proto) {
+    auto location = loc(proto.loc());
+
     // This is a generic function, the return type will be inferred later.
-    llvm::SmallVector<mlir::Type, 4> ret_types;
-    // Arguments type is uniformly a generic array.
+    // Argument types are uniformly unranked tensors.
     llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
                                                getType(VarType{}));
-    auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
-    auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
-                                         func_type, /* attrs = */ {});
+    auto func_type = builder.getFunctionType(arg_types, llvm::None);
+    auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
 
     // Mark the function as generic: it'll require type specialization for every
     // call site.
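    // (Illustration, not part of this change: since "toy.generic" is a
    // UnitAttr, it should print as a bare name in the function's attribute
    // dictionary, e.g. `attributes {toy.generic}` in the textual IR.)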
     if (function.getNumArguments())
-      function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
-
+      function.setAttr("toy.generic", builder.getUnitAttr());
     return function;
   }
 
@@ -165,29 +150,39 @@
     // argument list as the function itself.
     auto &entryBlock = *function.addEntryBlock();
     auto &protoArgs = funcAST.getProto()->getArgs();
+
     // Declare all the function arguments in the symbol table.
     for (const auto &name_value :
          llvm::zip(protoArgs, entryBlock.getArguments())) {
-      declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
+      if (failed(declare(std::get<0>(name_value)->getName(),
+                         std::get<1>(name_value))))
+        return nullptr;
     }
 
-    // Create a builder for the function, it will be used throughout the codegen
-    // to create operations in this function.
-    builder = std::make_unique<mlir::OpBuilder>(function.getBody());
+    // Set the insertion point in the builder to the beginning of the function
+    // body; it will be used throughout the codegen to create operations in
+    // this function.
+    builder.setInsertionPointToStart(&entryBlock);
 
     // Emit the body of the function.
-    if (!mlirGen(*funcAST.getBody())) {
+    if (mlir::failed(mlirGen(*funcAST.getBody()))) {
       function.erase();
       return nullptr;
     }
 
-    // Implicitly return void if no return statement was emited.
+    // Implicitly return void if no return statement was emitted.
     // FIXME: we may fix the parser instead to always return the last expression
     // (this would possibly help the REPL case later)
-    if (function.getBlocks().back().back().getName().getStringRef() !=
-        "toy.return") {
-      ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
-      mlirGen(fakeRet);
+    ReturnOp returnOp;
+    if (!entryBlock.empty())
+      returnOp = dyn_cast<ReturnOp>(entryBlock.back());
+    if (!returnOp) {
+      builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
+    } else if (returnOp.hasOperand()) {
+      // Otherwise, if this return operation has an operand then add a result to
+      // the function.
+      function.setType(builder.getFunctionType(function.getType().getInputs(),
+                                               getType(VarType{})));
     }
 
     return function;
@@ -206,11 +201,11 @@
     // and the result value is returned. If an error occurs we get a nullptr
     // and propagate.
     //
-    mlir::Value *L = mlirGen(*binop.getLHS());
-    if (!L)
+    mlir::Value *lhs = mlirGen(*binop.getLHS());
+    if (!lhs)
       return nullptr;
-    mlir::Value *R = mlirGen(*binop.getRHS());
-    if (!R)
+    mlir::Value *rhs = mlirGen(*binop.getRHS());
+    if (!rhs)
       return nullptr;
     auto location = loc(binop.loc());
 
@@ -218,123 +213,112 @@
     // support '+' and '*'.
     switch (binop.getOp()) {
     case '+':
-      return builder->create<AddOp>(location, L, R).getResult();
-      break;
+      return builder.create<AddOp>(location, lhs, rhs);
     case '*':
-      return builder->create<MulOp>(location, L, R).getResult();
-    default:
-      emitError(location, "error: invalid binary operator '")
-          << binop.getOp() << "'";
-      return nullptr;
+      return builder.create<MulOp>(location, lhs, rhs);
     }
+
+    emitError(location, "invalid binary operator '") << binop.getOp() << "'";
+    return nullptr;
   }
 
-  // This is a reference to a variable in an expression. The variable is
-  // expected to have been declared and so should have a value in the symbol
-  // table, otherwise emit an error and return nullptr.
+  /// This is a reference to a variable in an expression. The variable is
+  /// expected to have been declared and so should have a value in the symbol
+  /// table, otherwise emit an error and return nullptr.
   mlir::Value *mlirGen(VariableExprAST &expr) {
-    if (symbolTable.count(expr.getName()))
-      return symbolTable.lookup(expr.getName());
-    emitError(loc(expr.loc()), "error: unknown variable '")
+    if (auto *variable = symbolTable.lookup(expr.getName()))
+      return variable;
+
+    emitError(loc(expr.loc()), "error: unknown variable '")
         << expr.getName() << "'";
     return nullptr;
   }
 
-  // Emit a return operation, return true on success.
-  bool mlirGen(ReturnExprAST &ret) {
+  /// Emit a return operation. This will return failure if any generation fails.
+  mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
     auto location = loc(ret.loc());
-    // `return` takes an optional expression, we need to account for it here.
-    if (!ret.getExpr().hasValue()) {
-      builder->create<ReturnOp>(location);
-      return true;
+
+    // 'return' takes an optional expression, handle that case here.
+    mlir::Value *expr = nullptr;
+    if (ret.getExpr().hasValue()) {
+      if (!(expr = mlirGen(*ret.getExpr().getValue())))
+        return mlir::failure();
     }
-    auto *expr = mlirGen(*ret.getExpr().getValue());
-    if (!expr)
-      return false;
-    builder->create<ReturnOp>(location, expr);
-    return true;
+
+    // Otherwise, this return operation has zero operands.
+    builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
+                                            : ArrayRef<mlir::Value *>());
+    return mlir::success();
   }
 
-  // Emit a literal/constant array. It will be emitted as a flattened array of
-  // data in an Attribute attached to a `toy.constant` operation.
-  // See documentation on [Attributes](LangRef.md#attributes) for more details.
-  // Here is an excerpt:
-  //
-  //   Attributes are the mechanism for specifying constant data in MLIR in
-  //   places where a variable is never allowed [...]. They consist of a name
-  //   and a [concrete attribute value](#attribute-values). It is possible to
-  //   attach attributes to operations, functions, and function arguments. The
-  //   set of expected attributes, their structure, and their interpretation
-  //   are all contextually dependent on what they are attached to.
-  //
-  // Example, the source level statement:
-  //   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
-  // will be converted to:
-  //   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
-  //     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
-  //      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
-  //
+  /// Emit a literal/constant array. It will be emitted as a flattened array of
+  /// data in an Attribute attached to a `toy.constant` operation.
+  /// See documentation on [Attributes](LangRef.md#attributes) for more details.
+  /// Here is an excerpt:
+  ///
+  ///   Attributes are the mechanism for specifying constant data in MLIR in
+  ///   places where a variable is never allowed [...]. They consist of a name
+  ///   and a concrete attribute value. The set of expected attributes, their
+  ///   structure, and their interpretation are all contextually dependent on
+  ///   what they are attached to.
+  ///
+  /// Example, the source level statement:
+  ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
+  /// will be converted to:
+  ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
+  ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
+  ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
+  ///
   mlir::Value *mlirGen(LiteralExprAST &lit) {
-    auto location = loc(lit.loc());
-    // The attribute is a vector with an attribute per element (number) in the
-    // array, see `collectData()` below for more details.
-    std::vector<mlir::Attribute> data;
+    auto type = getType(lit.getDims());
+
+    // The attribute is a vector with a floating point value per element
+    // (number) in the array, see `collectData()` below for more details.
+    std::vector<double> data;
     data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
                                  std::multiplies<int>()));
     collectData(lit, data);
 
-    // FIXME: using a tensor type is a HACK here.
-    // Can we do differently without registering a dialect? Using a string blob?
-    mlir::Type elementType = mlir::FloatType::getF64(&context);
-    auto dataType = builder->getTensorType(lit.getDims(), elementType);
+    // The type of this attribute is tensor of 64-bit floating-point with the
+    // shape of the literal.
+    mlir::Type elementType = builder.getF64Type();
+    auto dataType = builder.getTensorType(lit.getDims(), elementType);
 
-    // This is the actual attribute that actually hold the list of values for
-    // this array literal.
-    auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
-                             .cast<mlir::DenseElementsAttr>();
+    // This is the actual attribute that holds the list of values for this
+    // tensor literal.
+    auto dataAttribute =
+        mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
 
-    // Build the MLIR op `toy.constant`, only boilerplate below.
-    return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
-        .getResult();
+    // Build the MLIR op `toy.constant`.
+    return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
   }
 
-  // Recursive helper function to accumulate the data that compose an array
-  // literal. It flattens the nested structure in the supplied vector. For
-  // example with this array:
-  //   [[1, 2], [3, 4]]
-  // we will generate:
-  //   [ 1, 2, 3, 4 ]
-  // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
-  // Attributes are the way MLIR attaches constant to operations and functions.
-  void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
+  /// Recursive helper function to accumulate the data that compose an array
+  /// literal. It flattens the nested structure in the supplied vector. For
+  /// example with this array:
+  ///   [[1, 2], [3, 4]]
+  /// we will generate:
+  ///   [ 1, 2, 3, 4 ]
+  /// Individual numbers are represented as doubles.
+  /// Attributes are the way MLIR attaches constant to operations.
+  void collectData(ExprAST &expr, std::vector<double> &data) {
     if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
       for (auto &value : lit->getValues())
         collectData(*value, data);
       return;
     }
+
     assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
-    mlir::Type elementType = mlir::FloatType::getF64(&context);
-    auto attr = mlir::FloatAttr::getChecked(
-        elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
-    data.push_back(attr);
+    data.push_back(cast<NumberExprAST>(expr).getValue());
   }
 
-  // Emit a call expression. It emits specific operations for the `transpose`
-  // builtin. Other identifiers are assumed to be user-defined functions.
+  /// Emit a call expression. It emits specific operations for the `transpose`
+  /// builtin. Other identifiers are assumed to be user-defined functions.
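  /// (Illustration, not part of this change: a Toy call such as
  /// `multiply_transpose(a, b)` is emitted roughly as
  ///   %2 = "toy.generic_call"(%0, %1) {callee = @multiply_transpose}
  ///        : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
  /// while `transpose(a)` maps directly onto a toy.transpose operation.)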
   mlir::Value *mlirGen(CallExprAST &call) {
+    llvm::StringRef callee = call.getCallee();
     auto location = loc(call.loc());
-    std::string callee = call.getCallee();
-    if (callee == "transpose") {
-      if (call.getArgs().size() != 1) {
-        emitError(location, "MLIR codegen encountered an error: toy.transpose "
-                            "does not accept multiple arguments");
-        return nullptr;
-      }
-      mlir::Value *arg = mlirGen(*call.getArgs()[0]);
-      return builder->create<TransposeOp>(location, arg).getResult();
-    }
 
-    // Codegen the operands first
+    // Codegen the operands first.
     SmallVector<mlir::Value *, 4> operands;
     for (auto &expr : call.getArgs()) {
       auto *arg = mlirGen(*expr);
@@ -342,34 +326,41 @@ private:
         return nullptr;
       operands.push_back(arg);
     }
-    // Calls to user-defined function are mapped to a custom call that takes
-    // the callee name as an attribute.
-    return builder->create<GenericCallOp>(location, call.getCallee(), operands)
-        .getResult();
+
+    // Builtin calls have their own custom operation, meaning this is a
+    // straightforward emission.
+    if (callee == "transpose") {
+      if (call.getArgs().size() != 1) {
+        emitError(location, "MLIR codegen encountered an error: toy.transpose "
+                            "does not accept multiple arguments");
+        return nullptr;
+      }
+      return builder.create<TransposeOp>(location, operands[0]);
+    }
+
+    // Otherwise this is a call to a user-defined function. Calls to
+    // user-defined functions are mapped to a custom call that takes the
+    // callee name as an attribute.
+    return builder.create<GenericCallOp>(location, callee, operands);
   }
 
-  // Emit a call expression. It emits specific operations for two builtins:
-  // transpose(x) and print(x). Other identifiers are assumed to be user-defined
-  // functions. Return false on failure.
-  bool mlirGen(PrintExprAST &call) {
+  /// Emit a print expression, lowering to the `toy.print` builtin operation.
+  mlir::LogicalResult mlirGen(PrintExprAST &call) {
     auto *arg = mlirGen(*call.getArg());
     if (!arg)
-      return false;
-    auto location = loc(call.loc());
-    builder->create<PrintOp>(location, arg);
-    return true;
+      return mlir::failure();
+
+    builder.create<PrintOp>(loc(call.loc()), arg);
+    return mlir::success();
   }
 
-  // Emit a constant for a single number (FIXME: semantic? broadcast?)
+  /// Emit a constant for a single number (FIXME: semantic? broadcast?)
   mlir::Value *mlirGen(NumberExprAST &num) {
-    auto location = loc(num.loc());
-    mlir::Type elementType = mlir::FloatType::getF64(&context);
-    auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
-                                            loc(num.loc()));
-    return builder->create<ConstantOp>(location, attr).getResult();
+    return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
   }
 
-  // Dispatch codegen for the right expression subclass using RTTI.
+  /// Dispatch codegen for the right expression subclass using RTTI.
   mlir::Value *mlirGen(ExprAST &expr) {
     switch (expr.getKind()) {
     case toy::ExprAST::Expr_BinOp:
@@ -390,77 +381,75 @@ private:
     }
   }
 
-  // Handle a variable declaration, we'll codegen the expression that forms the
-  // initializer and record the value in the symbol table before returning it.
-  // Future expressions will be able to reference this variable through symbol
-  // table lookup.
+  /// Handle a variable declaration; we'll codegen the expression that forms the
+  /// initializer and record the value in the symbol table before returning it.
+  /// Future expressions will be able to reference this variable through symbol
+  /// table lookup.
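  /// (Illustration, not part of this change: a declaration such as
  /// `var a<2, 2> = [1, 2, 3, 4];` emits a toy.constant holding the flat data
  /// followed by a toy.reshape to tensor<2x2xf64>, per the reshape logic
  /// below.)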
   mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
-    mlir::Value *value = nullptr;
-    auto location = loc(vardecl.loc());
-    if (auto init = vardecl.getInitVal()) {
-      value = mlirGen(*init);
-      if (!value)
-        return nullptr;
-      // We have the initializer value, but in case the variable was declared
-      // with specific shape, we emit a "reshape" operation. It will get
-      // optimized out later as needed.
-      if (!vardecl.getType().shape.empty()) {
-        value = builder
-                    ->create<ReshapeOp>(
-                        location, value,
-                        getType(vardecl.getType()).cast<ToyArrayType>())
-                    .getResult();
-      }
-    } else {
+    auto init = vardecl.getInitVal();
+    if (!init) {
       emitError(loc(vardecl.loc()),
                 "missing initializer in variable declaration");
       return nullptr;
     }
-    // Register the value in the symbol table
-    declare(vardecl.getName(), value);
+
+    mlir::Value *value = mlirGen(*init);
+    if (!value)
+      return nullptr;
+
+    // We have the initializer value, but in case the variable was declared
+    // with specific shape, we emit a "reshape" operation. It will get
+    // optimized out later as needed.
+    if (!vardecl.getType().shape.empty()) {
+      value = builder.create<ReshapeOp>(loc(vardecl.loc()),
+                                        getType(vardecl.getType()), value);
+    }
+
+    // Register the value in the symbol table.
+    if (failed(declare(vardecl.getName(), value)))
+      return nullptr;
     return value;
   }

-  /// Codegen a list of expression, return false if one of them hit an error.
-  bool mlirGen(ExprASTList &blockAST) {
-    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
+  /// Codegen a list of expressions; return failure if one of them hits an
+  /// error.
+  mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
+    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
     for (auto &expr : blockAST) {
       // Specific handling for variable declarations, return statement, and
       // print. These can only appear in a block list and not in nested
       // expressions.
       if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
         if (!mlirGen(*vardecl))
-          return false;
+          return mlir::failure();
         continue;
       }
-      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
-        if (!mlirGen(*ret))
-          return false;
-        return true;
-      }
+      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
+        return mlirGen(*ret);
       if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
-        if (!mlirGen(*print))
-          return false;
+        if (mlir::failed(mlirGen(*print)))
+          return mlir::failure();
         continue;
       }
+
       // Generic expression dispatch codegen.
       if (!mlirGen(*expr))
-        return false;
+        return mlir::failure();
     }
-    return true;
+    return mlir::success();
   }

-  /// Build a type from a list of shape dimensions. Types are `array` followed
-  /// by an optional dimension list, example: array<2, 2>
-  /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
-  ///   !toy.array<2, 2>
-  template <typename T> mlir::Type getType(T shape) {
-    SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
-    return ToyArrayType::get(&context, shape64);
+  /// Build a tensor type from a list of shape dimensions.
+  mlir::Type getType(ArrayRef<int64_t> shape) {
+    // If the shape is empty, then this type is unranked.
+    if (shape.empty())
+      return builder.getTensorType(builder.getF64Type());
+
+    // Otherwise, we use the given shape.
+    return builder.getTensorType(shape, builder.getF64Type());
   }

-  /// Build an MLIR type from a Toy AST variable type
-  /// (forward to the generic getType(T) above).
+  /// Build an MLIR type from a Toy AST variable type (forward to the generic
+  /// getType above).
   mlir::Type getType(const VarType &type) { return getType(type.shape); }
 };

diff --git a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
index 1600b99ec013..b8b091a62c5e 100644
--- a/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
@@ -1,4 +1,4 @@
-//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
+//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -15,22 +15,14 @@
 // limitations under the License.
 // =============================================================================
 //
-// This file implements a Module level pass performing interprocedural
-// propagation of array shapes through function specialization.
+// This file implements a Function level pass performing intraprocedural
+// propagation of array shapes.
 //
 //===----------------------------------------------------------------------===//

-#include "toy/Dialect.h"
-
-#include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/DenseSet.h"
+#include "toy/Dialect.h"
+#include "toy/Passes.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
@@ -39,48 +31,26 @@
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>

-#define DEBUG_TYPE "toy-shape-inference"
+#define DEBUG_TYPE "shape-inference"

-using namespace toy;
 using llvm::MutableArrayRef;
+using llvm::raw_ostream;
 using llvm::SmallVector;
 using llvm::SmallVectorImpl;
 using llvm::StringRef;
 using llvm::Twine;
-
-/// Create a mangled name for function specialization. We will simply append the
-/// shape of the arguments to the function name. For example, calling
-///
-///   "toy.generic_call"(%1, %3) {callee: "foo"}
-///       : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
-///
-/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
-/// have provided a function with a similar name, but we will claim this as a
-/// feature: this allows the user to provide custom specializations!
-static std::string mangle(StringRef funcName,
-                          MutableArrayRef<mlir::OpOperand> operands) {
-  std::string mangledName;
-  mangledName.reserve(funcName.size() + operands.size() * 6);
-  mangledName = funcName;
-  for (auto &operand : operands) {
-    auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
-    mangledName += "_";
-    mlir::interleave(
-        arrayTy.getShape(),
-        [&](int64_t dim) { mangledName += Twine(dim).str(); },
-        [&]() { mangledName += "x"; });
-  }
-  return mangledName;
-}
+using namespace mlir;

 namespace {
-/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
-/// whole. MLIR also supports FunctionPass which are restricted to modify a
-/// single function at a time. This pass couldn't be a function pass due the
-/// nature of its interprocedural transformations.
+// clang-format off
+#include "toy/ShapeInferenceOpInterfaces.h.inc"
+#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
+
+/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
+/// shape inference.
 ///
-/// The algorithm has two levels, first intra-procedurally:
+/// Algorithm:
 ///
 ///   1) Build a worklist containing all the operations that are returning
 ///      a generic Toy array: these are the operations that need shape
 ///      inference.
@@ -94,132 +64,25 @@ namespace {
 ///   3) If the worklist is empty, the algorithm succeeded and we infer the
 ///      return type for the function from the return operation.
 ///
-/// There is a twist though: when a call to a generic function is encountered,
-/// shape inference requires the return type of the callee to be inferred first.
-/// At this point we need to run specialize the callee by cloning it. Here is
-/// the inter-procedural flow:
-///
-///   1) Keep a worklist of function to process. Start with function "main".
-///   2) While the worklist isn't empty:
-///     a) Take the last inserted function in the worklist.
-///     b) Run the intra-procedural shape inference on this function.
-///     c) If the intra-procedural shape inference can't complete, it returns
-///        a Function that needs to be inferred first. In this case, queue this
-///        new function and continue. Otherwise the inference succeeded and we
-///        can pop from the queue.
-///
-class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
+class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
 public:
-  // One entry in the inter-procedural worklist. It keeps track of the
-  // function to process, the mangled name for this specialization, and the
-  // types of the arguments on which to specialize.
-  struct FunctionToSpecialize {
-    mlir::FuncOp function;
-    std::string mangledName;
-    SmallVector<mlir::Type, 4> argumentsType;
-  };
-
-  void runOnModule() override {
-    auto module = getModule();
-    auto main = module.lookupSymbol<mlir::FuncOp>("main");
-    if (!main) {
-      emitError(mlir::UnknownLoc::get(module.getContext()),
-                "shape inference failed: can't find a main function\n");
-      signalPassFailure();
-      return;
-    }
-
-    /// Inter-procedural loop, initialize with `main` and iterate until we
-    /// successfully infer the full reachable call-graph from main.
-    SmallVector<FunctionToSpecialize, 8> worklist;
-    worklist.push_back({main, "", {}});
-    while (!worklist.empty()) {
-      if (failed(specialize(worklist)))
-        return;
-    }
-
-    // Delete any generic function left
-    // FIXME: we may want this as a separate pass.
-    for (mlir::FuncOp function :
-         llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
-      if (auto genericAttr =
-              function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
-        if (genericAttr.getValue())
-          function.erase();
-      }
+  bool returnsGenericArray(Operation *op) {
+    if (op->getNumResults() == 1) {
+      if (!op->getResult(0)->getType().isa<RankedTensorType>())
+        return true;
     }
+    return false;
   }

-  /// Run inference on a function. If a mangledName is provided, we need to
-  /// specialize the function: to this end clone it first.
-  mlir::LogicalResult
-  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
-    FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
-    mlir::FuncOp f = functionToSpecialize.function;
-
-    // Check if cloning for specialization is needed (usually anything but main)
-    // We will create a new function with the concrete types for the parameters
-    // and clone the body into it.
-    if (!functionToSpecialize.mangledName.empty()) {
-      if (getModule().lookupSymbol<mlir::FuncOp>(
-              functionToSpecialize.mangledName)) {
-        funcWorklist.pop_back();
-        // Function already specialized, move on.
-        return mlir::success();
-      }
-      // Create a new function with a generic array return type, it will be
-      // updated when the inference for the function body completes.
-      auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
-                                          {ToyArrayType::get(&getContext())},
-                                          &getContext());
-      auto newFunction =
-          mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName,
-                               type, f.getDialectAttrs());
-      getModule().push_back(newFunction);
-
-      // Clone the function body
-      mlir::BlockAndValueMapping mapper;
-      f.cloneInto(newFunction, mapper);
-      LLVM_DEBUG({
-        llvm::dbgs() << "====== Cloned : \n";
-        f.dump();
-        llvm::dbgs() << "====== Into : \n";
-        newFunction.dump();
-      });
-      f = newFunction;
-      f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-      // Remap the entry-block arguments
-      // FIXME: this seems like a bug in `cloneInto()` above?
-      auto &entryBlock = f.getBlocks().front();
-      int blockArgSize = entryBlock.getArguments().size();
-      assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
-      entryBlock.addArguments(f.getType().getInputs());
-      auto argList = entryBlock.getArguments();
-      for (int argNum = 0; argNum < blockArgSize; ++argNum) {
-        argList[0]->replaceAllUsesWith(argList[blockArgSize]);
-        entryBlock.eraseArgument(0);
-      }
-      assert(succeeded(mlir::verify(f)));
-    }
-    LLVM_DEBUG(llvm::dbgs()
-               << "Run shape inference on : '" << f.getName() << "'\n");
-
-    auto *toyDialect = getContext().getRegisteredDialect("toy");
-    if (!toyDialect) {
-      emitError(mlir::UnknownLoc::get(&getContext()),
-                "Toy dialect is not registered");
-      signalPassFailure();
-      return mlir::failure();
-    }
+  void runOnFunction() override {
+    auto f = getFunction();

     // Populate the worklist with the operations that need shape inference:
-    // these are the Toy operations that return a generic array.
+    // these are operations that return a generic array.
     llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
     f.walk([&](mlir::Operation *op) {
-      if (op->getDialect() == toyDialect) {
-        if (op->getNumResults() == 1 &&
-            op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
-          opWorklist.insert(op);
+      if (returnsGenericArray(op)) {
+        opWorklist.insert(op);
       }
     });
@@ -228,154 +91,31 @@ public:
     while (!opWorklist.empty()) {
       // Find the next operation ready for inference, that is an operation
       // with all operands already resolved (non-generic).
-      auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
-        return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
-          return !ty.cast<ToyArrayType>().isGeneric();
-        });
+      auto nextop = llvm::find_if(opWorklist, [](Operation *op) {
+        return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
+          return ty.isa<RankedTensorType>();
+        });
       });
+
       if (nextop == opWorklist.end())
         break; // failure: no operations can be inferred.

-      mlir::Operation *op = *nextop;
+      Operation *op = *nextop;
       opWorklist.erase(op);
       LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
-
-      // The add operation is trivial: propagate the input type as is.
-      if (auto addOp = llvm::dyn_cast<AddOp>(op)) {
-        op->getResult(0)->setType(op->getOperand(0)->getType());
-        continue;
-      }
-
-      // Transpose is easy: just invert the dimensions.
-      if (auto transpose = llvm::dyn_cast<TransposeOp>(op)) {
-        SmallVector<int64_t, 2> dims;
-        auto arrayTy = transpose.getOperand()->getType().cast<ToyArrayType>();
-        dims.insert(dims.end(), arrayTy.getShape().begin(),
-                    arrayTy.getShape().end());
-        transpose.getResult()->setType(ToyArrayType::get(&getContext(), dims));
-        continue;
-      }
-
-      // Multiplication is a bit trickier, handle rank 1 as dot product and rank
-      // 2 as matrix multiplications.
-      // We need to be careful about rank mismatch here: the verifier could
-      // catch it but shape inference earlier in the pass could generate an
-      // invalid IR (from an invalid Toy input of course) and we wouldn't want
-      // to crash here.
-      if (auto mulOp = llvm::dyn_cast<MulOp>(op)) {
-        auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
-        auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
-        auto lhsRank = lhs.getShape().size();
-        auto rhsRank = rhs.getShape().size();
-        if (lhsRank != rhsRank) {
-          return mulOp.emitOpError(
-              "shape mismatch: LHS and RHS must have the same "
-              "rank for multiplication, got " +
-              Twine(lhsRank) + " vs " + Twine(lhsRank));
-        }
-        SmallVector<int64_t, 2> dims;
-        if (lhsRank == 1) {
-          // dot product, result shape is <1>
-          dims.push_back(1);
-        } else if (lhsRank != 2) {
-          return op->emitOpError(
-              "shape mismatch: expect rank 1 or 2 for mul operands, got " +
-              Twine(lhsRank));
-        } else {
-          dims.push_back(lhs.getShape()[0]);
-          dims.push_back(rhs.getShape()[1]);
-        }
-        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
-        continue;
-      }
-
-      // Process calls: lookup the callee after mangling the name with the
-      // argument shapes. If the callee does not exist, we stop the inference
-      // for this function, queue the callee in the inter-procedural work list,
-      // and return. The current function stays in the work list and will
-      // restart after the callee is processed.
-      if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
-        auto calleeName = callOp.getCalleeName();
-        auto callee = getModule().lookupSymbol<mlir::FuncOp>(calleeName);
-        if (!callee) {
-          f.emitError("shape inference failed, call to unknown '")
-              << calleeName << "'";
-          signalPassFailure();
-          return mlir::failure();
-        }
-        auto mangledName = mangle(calleeName, op->getOpOperands());
-        LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
-                                << "', mangled: '" << mangledName << "'\n");
-        auto mangledCallee =
-            getModule().lookupSymbol<mlir::FuncOp>(mangledName);
-        if (!mangledCallee) {
-          // Can't find the target, this is where we queue the request for the
-          // callee and stop the inference for the current function now.
-          funcWorklist.push_back({callee, std::move(mangledName),
-                                  llvm::to_vector<4>(op->getOperandTypes())});
-          return mlir::success();
-        }
-        // Found a specialized callee! Let's turn this into a normal call
-        // operation.
-        SmallVector<mlir::Value *, 8> operands(op->getOperands());
-        mlir::OpBuilder builder(op);
-        auto newCall =
-            builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
-        if (newCall.getNumResults()) {
-          op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
-          op->erase();
-          continue;
-        }
-      }
+      auto shapeOp = dyn_cast<ShapeInferenceOpInterface>(op);
+      assert(shapeOp && "worklist op must implement the shape inference "
+                        "interface");
+      shapeOp.inferShapes();
     }

-    // Done with inference on this function, removing it from the worklist.
-    funcWorklist.pop_back();
-    // Mark the function as non-generic now that inference has succeeded
-    f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
-
     // If the operation worklist isn't empty, this indicates a failure.
     if (!opWorklist.empty()) {
-      std::string str;
-      llvm::raw_string_ostream errorMsg(str);
-      errorMsg << "shape inference failed, " << opWorklist.size()
-               << " operations couldn't be inferred\n";
-      for (auto *ope : opWorklist)
-        errorMsg << " - " << *ope << "\n";
-      f.emitError(errorMsg.str());
       signalPassFailure();
-      return mlir::failure();
+      auto diag = f.emitError("shape inference failed, ")
+                  << opWorklist.size() << " operations couldn't be inferred\n";
     }
-
-    // Finally, update the return type of the function based on the argument to
-    // the return operation.
-    for (auto &block : f.getBlocks()) {
-      auto ret = llvm::cast<ReturnOp>(block.getTerminator());
-      if (!ret)
-        continue;
-      if (ret.getNumOperands() &&
-          f.getType().getResult(0) == ret.getOperand()->getType())
-        // type match, we're done
-        break;
-      SmallVector<mlir::Type, 1> retTy;
-      if (ret.getNumOperands())
-        retTy.push_back(ret.getOperand()->getType());
-      std::vector<mlir::Type> argumentsType;
-      for (auto arg : f.getArguments())
-        argumentsType.push_back(arg->getType());
-      auto newType =
-          mlir::FunctionType::get(argumentsType, retTy, &getContext());
-      f.setType(newType);
-      assert(succeeded(mlir::verify(f)));
-      break;
-    }
-    return mlir::success();
   }
 };
 } // end anonymous namespace

-namespace toy {
-std::unique_ptr<mlir::Pass> createShapeInferencePass() {
+/// Create a Shape Inference pass.
+std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
   return std::make_unique<ShapeInferencePass>();
 }
-} // namespace toy
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
index b89cb85ff06d..1b9dcd202919 100644
--- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp
@@ -15,24 +15,25 @@
 // limitations under the License.
 // =============================================================================
 //
-// This file implements a simple combiner for optimizing pattern in the Toy
-// dialect.
+// This file implements a set of simple combiners for optimizing operations in
+// the Toy dialect.
 //
 //===----------------------------------------------------------------------===//

-#include "toy/Dialect.h"
-
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-
+#include "toy/Dialect.h"
 #include <numeric>
-
-namespace toy {
+using namespace mlir;
+using namespace toy;

 namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "ToyCombine.inc"
+} // end anonymous namespace

-/// Fold transpose(transpose(x) -> transpose(x)
+/// This is an example of a C++ rewrite pattern for the TransposeOp. It
+/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
   /// We register this pattern to match every toy.transpose in the IR.
   /// The "benefit" is used by the framework to order the patterns and process
@@ -40,9 +41,9 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
   SimplifyRedundantTranspose(mlir::MLIRContext *context)
       : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

-  /// This method is attempting to match a pattern and rewrite it. The rewriter
-  /// argument is the orchestrator of the sequence of rewrites. It is expected
-  /// to interact with it to perform any changes to the IR from here.
+  /// This method attempts to match a pattern and rewrite it. The rewriter
+  /// argument is the orchestrator of the sequence of rewrites. The pattern is
+  /// expected to interact with it to perform any changes to the IR from here.
   mlir::PatternMatchResult
   matchAndRewrite(TransposeOp op,
                   mlir::PatternRewriter &rewriter) const override {
     // Look through the input of the current transpose.
     mlir::Value *transposeInput = op.getOperand();
     TransposeOp transposeInputOp =
         llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
+    // If the input is defined by another Transpose, bingo!
     if (!transposeInputOp)
       return matchFailure();

-    // Use the rewriter to perform the replacement
+    // Use the rewriter to perform the replacement.
     rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
     return matchSuccess();
   }
 };

-/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
-struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
-  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  mlir::PatternMatchResult
-  matchAndRewrite(ReshapeOp reshape,
-                  mlir::PatternRewriter &rewriter) const override {
-    // Look through the input of the current reshape.
-    ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
-        reshape.getOperand()->getDefiningOp());
-    // If the input is defined by another constant, bingo!
-    if (!constantOp)
-      return matchFailure();
-
-    auto reshapeType = reshape.getType().cast<ToyArrayType>();
-    if (auto valueAttr =
-            constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
-      // FIXME Check matching of element count!
-      //      auto oldType = constantOp.getType();
-      auto newType = rewriter.getTensorType(
-          reshapeType.getShape(), valueAttr.getType().getElementType());
-      auto newAttr = valueAttr.reshape(newType);
-      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
-                                              newAttr);
-    } else if (auto valueAttr =
-                   constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
-      // Broadcast
-      auto dataSize = std::accumulate(reshapeType.getShape().begin(),
-                                      reshapeType.getShape().end(), 1,
-                                      std::multiplies<int>());
-      std::vector<mlir::Attribute> data(dataSize, valueAttr);
-      auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
-                                             reshapeType.getElementType());
-      auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
-      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
-                                              newAttr);
-    } else {
-      llvm_unreachable("Unsupported Constant format");
-    }
-    return matchSuccess();
-  }
-};
-
-/// Fold reshape(reshape(x)) -> reshape(x)
-struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
-  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  mlir::PatternMatchResult
-  matchAndRewrite(ReshapeOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-    // Look through the input of the current reshape.
-    mlir::Value *reshapeInput = op.getOperand();
-
-    // If the input is defined by another reshape, bingo!
-    if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
-      return matchFailure();
-
-    // Use the rewriter to perform the replacement
-    rewriter.replaceOp(op, {reshapeInput});
-    return matchSuccess();
-  }
-};
-
-/// Fold reshape(x)) -> x, when input type matches output type
-struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
-  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  mlir::PatternMatchResult
-  matchAndRewrite(ReshapeOp op,
-                  mlir::PatternRewriter &rewriter) const override {
-    if (op.getOperand()->getType() != op.getType())
-      return matchFailure();
-    rewriter.replaceOp(op, {op.getOperand()});
-    return matchSuccess();
-  }
-};
-
-} // end anonymous namespace.
-
-// Register our patterns for rewrite by the Canonicalization framework.
-void TransposeOp::getCanonicalizationPatterns(
-    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
+/// Register our patterns as "canonicalization" patterns on the TransposeOp so
+/// that they can be picked up by the Canonicalization framework.
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
   results.insert<SimplifyRedundantTranspose>(context);
 }

-// Register our patterns for rewrite by the Canonicalization framework.
-void ReshapeOp::getCanonicalizationPatterns(
-    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
-  results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
-                 SimplifyNullReshape>(context);
+/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
+/// that they can be picked up by the Canonicalization framework.
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+                 FoldConstantReshapeOptPattern>(context);
 }
-
-} // namespace toy
diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.td b/mlir/examples/toy/Ch4/mlir/ToyCombine.td
new file mode 100644
index 000000000000..97b9be4c3530
--- /dev/null
+++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.td
@@ -0,0 +1,73 @@
+//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines language-specific pattern match optimizations for Toy using
+// Declarative Rewrite Rules (DRR) specified using TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_COMBINE
+#define TOY_COMBINE
+
+#ifndef OP_BASE
+include "toy/Ops.td"
+#endif // OP_BASE
+
+/// Note: The DRR definition used for defining patterns is shown below:
+///
+/// class Pattern<
+///    dag sourcePattern, list<dag> resultPatterns,
+///    list<dag> additionalConstraints = [],
+///    dag benefitsAdded = (addBenefit 0)
+/// >;
+
+//===----------------------------------------------------------------------===//
+// Basic Pattern-Match and Rewrite
+//===----------------------------------------------------------------------===//
+
+// Reshape(Reshape(x)) = Reshape(x)
+def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
+                                   (ReshapeOp $arg)>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite using Native Code Call
+//===----------------------------------------------------------------------===//
+
+// Native Code Calls may be used for more complex transformations using inline
+// C++ and C++ helper functions.
+
+// Reshape(Constant(x)) = x'
+def ReshapeConstant :
+  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
+def FoldConstantReshapeOptPattern : Pat<
+  (ReshapeOp:$res (ConstantOp $arg)),
+  (ConstantOp (ReshapeConstant $arg, $res))>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite with Constraints
+//===----------------------------------------------------------------------===//
+
+// DRR allows for constraint checking when the transformation is conditional
+// on operand properties.
+
+// Reshape(x) = x, where input and output shapes are identical
+def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
+def RedundantReshapeOptPattern : Pat<
+  (ReshapeOp:$res $arg), (replaceWithValue $arg),
+  [(TypesAreIdentical $res, $arg)]>;
+
+#endif // TOY_COMBINE
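For orientation, each `Pat<...>` record above is expanded by the
`mlir_tablegen(ToyCombine.inc -gen-rewriters ...)` rule from the CMake changes
into an ordinary C++ rewrite pattern in `ToyCombine.inc`. The following is only
a hand-written sketch of the rough shape of what gets generated for
`ReshapeReshapeOptPattern`; the real generated code differs in detail:

```c++
// Hand-written sketch only: the real ToyCombine.inc is produced by
// `mlir-tblgen -gen-rewriters` and is more general.
struct ReshapeReshapeOptPattern : public mlir::OpRewritePattern<ReshapeOp> {
  using OpRewritePattern<ReshapeOp>::OpRewritePattern;

  mlir::PatternMatchResult
  matchAndRewrite(ReshapeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Source pattern: the operand must itself be produced by a toy.reshape.
    auto inner = llvm::dyn_cast_or_null<ReshapeOp>(
        op.getOperand()->getDefiningOp());
    if (!inner)
      return matchFailure();
    // Result pattern: Reshape(Reshape(x)) -> Reshape(x), preserving the
    // outer result type.
    rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(),
                                           inner.getOperand());
    return matchSuccess();
  }
};
```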
diff --git a/mlir/examples/toy/Ch4/mlir/ToyDialect.cpp b/mlir/examples/toy/Ch4/mlir/ToyDialect.cpp
deleted file mode 100644
index f77754e368ff..000000000000
--- a/mlir/examples/toy/Ch4/mlir/ToyDialect.cpp
+++ /dev/null
@@ -1,387 +0,0 @@
-//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements the dialect for the Toy IR: custom type parsing and
-// operation verification.
-//
-//===----------------------------------------------------------------------===//
-
-#include "toy/Dialect.h"
-
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Support/STLExtras.h"
-#include "llvm/ADT/iterator_range.h"
-#include "llvm/Support/ErrorHandling.h"
-#include "llvm/Support/Regex.h"
-#include "llvm/Support/raw_ostream.h"
-
-using llvm::ArrayRef;
-using llvm::raw_ostream;
-using llvm::raw_string_ostream;
-using llvm::SmallVector;
-using llvm::StringRef;
-using llvm::Twine;
-
-namespace toy {
-namespace detail {
-
-/// This class holds the implementation of the ToyArrayType.
-/// It is intended to be uniqued based on its content and owned by the context.
-struct ToyArrayTypeStorage : public mlir::TypeStorage {
-  /// This defines how we unique this type in the context: our key contains
-  /// only the shape, a more complex type would have multiple entries in the
-  /// tuple here.
-  /// The element of the tuples usually matches 1-1 the arguments from the
-  /// public `get()` method arguments from the facade.
-  using KeyTy = std::tuple<ArrayRef<int64_t>>;
-  static unsigned hashKey(const KeyTy &key) {
-    return llvm::hash_combine(std::get<0>(key));
-  }
-  /// When the key hash hits an existing type, we compare the shape themselves
-  /// to confirm we have the right type.
-  bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
-
-  /// This is a factory method to create our type storage. It is only
-  /// invoked after looking up the type in the context using the key and not
-  /// finding it.
-  static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
-                                        const KeyTy &key) {
-    // Copy the shape array into the bumpptr allocator owned by the context.
-    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
-
-    // Allocate the instance for the ToyArrayTypeStorage itself
-    auto *storage = allocator.allocate<ToyArrayTypeStorage>();
-    // Initialize the instance using placement new.
-    return new (storage) ToyArrayTypeStorage(shape);
-  }
-
-  ArrayRef<int64_t> getShape() const { return shape; }
-
-private:
-  ArrayRef<int64_t> shape;
-
-  /// Constructor is only invoked from the `construct()` method above.
-  ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
-};
-
-} // namespace detail
-
-mlir::Type ToyArrayType::getElementType() {
-  return mlir::FloatType::getF64(getContext());
-}
-
-ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
-                               ArrayRef<int64_t> shape) {
-  return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
-}
-
-ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
-
-/// Dialect creation, the instance will be owned by the context. This is the
-/// point of registration of custom types and operations for the dialect.
-ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
-  addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
-                MulOp, AddOp, ReturnOp>();
-  addTypes<ToyArrayType>();
-}
-
-/// Parse a type registered to this dialect, we expect only Toy arrays.
-mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
-  // Sanity check: we only support array or array<...>
-  if (!tyData.startswith("array")) {
-    emitError(loc, "invalid Toy type '" + tyData + "', array expected");
-    return nullptr;
-  }
-  // Drop the "array" prefix from the type name, we expect either an empty
-  // string or just the shape.
-  tyData = tyData.drop_front(StringRef("array").size());
-  // This is the generic array case without shape, early return it.
-  if (tyData.empty())
-    return ToyArrayType::get(getContext());
-
-  // Use a regex to parse the shape (for efficient we should store this regex in
-  // the dialect itself).
-  SmallVector<StringRef, 4> matches;
-  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
-  if (!shapeRegex.match(tyData, &matches)) {
-    emitError(loc, "invalid toy array shape '" + tyData + "'");
-    return nullptr;
-  }
-  SmallVector<int64_t, 4> shape;
-  // Iterate through the captures, skip the first one which is the full string.
-  for (auto dimStr :
-       llvm::make_range(std::next(matches.begin()), matches.end())) {
-    if (dimStr.startswith(","))
-      continue; // POSIX misses non-capturing groups.
-    if (dimStr.empty())
-      continue; // '*' makes it an optional group capture
-    // Convert the capture to an integer
-    unsigned long long dim;
-    if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
-      emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr);
-      return mlir::Type();
-    }
-    shape.push_back(dim);
-  }
-  // Finally we collected all the dimensions in the shape,
-  // create the array type.
-  return ToyArrayType::get(getContext(), shape);
-}
-
-/// Print a Toy array type, for example `array<2, 3, 4>`
-void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
-  auto arrayTy = type.dyn_cast<ToyArrayType>();
-  if (!arrayTy) {
-    os << "unknown toy type";
-    return;
-  }
-  os << "array";
-  if (!arrayTy.getShape().empty()) {
-    os << "<";
-    mlir::interleaveComma(arrayTy.getShape(), os);
-    os << ">";
-  }
-}
-
-////////////////////////////////////////////////////////////////////////////////
-//////////////////// Custom Operations for the Dialect /////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-/// Helper to verify that the result of an operation is a Toy array type.
-template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
-  if (!op->getResult()->getType().template isa<ToyArrayType>()) {
-    std::string msg;
-    raw_string_ostream os(msg);
-    os << "expects a Toy Array for its argument, got "
-       << op->getResult()->getType();
-    return op->emitOpError(os.str());
-  }
-  return mlir::success();
-}
-
-/// Helper to verify that the two operands of a binary operation are Toy
-/// arrays..
-template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
-  if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
-    std::string msg;
-    raw_string_ostream os(msg);
-    os << "expects a Toy Array for its LHS, got "
-       << op->getOperand(0)->getType();
-    return op->emitOpError(os.str());
-  }
-  if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
-    std::string msg;
-    raw_string_ostream os(msg);
-    os << "expects a Toy Array for its LHS, got "
-       << op->getOperand(0)->getType();
-    return op->emitOpError(os.str());
-  }
-  return mlir::success();
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, so is the state that this method is
-/// expected to fill in order to build the operation.
-void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                       ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
-  state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
-  auto dataAttribute = builder->getNamedAttr("value", value);
-  state.attributes.push_back(dataAttribute);
-}
-
-/// Build a constant operation.
-/// The builder is passed as an argument, so is the state that this method is
-/// expected to fill in order to build the operation.
-void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                       mlir::FloatAttr value) {
-  // Broadcast and forward to the other build factory
-  mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
-  auto dataType = builder->getTensorType({1}, elementType);
-  auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
-                           .cast<mlir::DenseElementsAttr>();
-
-  ConstantOp::build(builder, state, {1}, dataAttribute);
-}
-
-/// Verifier for constant operation.
-mlir::LogicalResult ConstantOp::verify() {
-  // Ensure that the return type is a Toy array
-  if (failed(verifyToyReturnArray(this)))
-    return mlir::failure();
-
-  // We expect the constant itself to be stored as an attribute.
-  auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
-  if (!dataAttr) {
-    return emitOpError(
-        "missing valid `value` DenseElementsAttribute on toy.constant()");
-  }
-  auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
-  if (!attrType) {
-    return emitOpError(
-        "missing valid `value` DenseElementsAttribute on toy.constant()");
-  }
-
-  // If the return type of the constant is not a generic array, the shape must
-  // match the shape of the attribute holding the data.
-  auto resultType = getResult()->getType().cast<ToyArrayType>();
-  if (!resultType.isGeneric()) {
-    if (attrType.getRank() != resultType.getRank()) {
-      return emitOpError("The rank of the toy.constant return type must match "
-                         "the one of the attached value attribute: " +
-                         Twine(attrType.getRank()) +
-                         " != " + Twine(resultType.getRank()));
-    }
-    for (int dim = 0; dim < attrType.getRank(); ++dim) {
-      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
-        std::string msg;
-        raw_string_ostream os(msg);
-        return emitOpError(
-            "Shape mismatch between toy.constant return type and its "
-            "attribute at dimension " +
-            Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
-            " != " + Twine(resultType.getShape()[dim]));
-      }
-    }
-  }
-  return mlir::success();
-}
-
-void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                          StringRef callee, ArrayRef<mlir::Value *> arguments) {
-  // Generic call always returns a generic ToyArray initially
-  state.types.push_back(ToyArrayType::get(builder->getContext()));
-  state.operands.assign(arguments.begin(), arguments.end());
-  auto calleeAttr = builder->getStringAttr(callee);
-  state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
-}
-
-mlir::LogicalResult GenericCallOp::verify() {
-  // Verify that every operand is a Toy Array
-  for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
-    if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
-      std::string msg;
-      raw_string_ostream os(msg);
-      os << "expects a Toy Array for its " << opId << " operand, got "
-         << getOperand(opId)->getType();
-      return emitOpError(os.str());
-    }
-  }
-  return mlir::success();
-}
-
-/// Return the name of the callee.
-StringRef GenericCallOp::getCalleeName() {
-  return getAttr("callee").cast<mlir::StringAttr>().getValue();
-}
-
-template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
-  if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
-    std::string msg;
-    raw_string_ostream os(msg);
-    os << "expects a Toy Array for its argument, got "
-       << op->getOperand()->getType();
-    return op->emitOpError(os.str());
-  }
-  return mlir::success();
-}
-
-void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                     mlir::Value *value) {
-  // Return does not return any value and has an optional single argument
-  if (value)
-    state.operands.push_back(value);
-}
-
-mlir::LogicalResult ReturnOp::verify() {
-  if (getNumOperands() > 1)
-    return emitOpError("expects zero or one operand, got " +
-                       Twine(getNumOperands()));
-  if (hasOperand() && failed(verifyToySingleOperand(this)))
-    return mlir::failure();
-  return mlir::success();
-}
-
-void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value) {
-  // Print does not return any value and has a single argument
-  state.operands.push_back(value);
-}
-
-mlir::LogicalResult PrintOp::verify() {
-  if (failed(verifyToySingleOperand(this)))
-    return mlir::failure();
-  return mlir::success();
-}
-
-void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                        mlir::Value *value) {
-  state.types.push_back(ToyArrayType::get(builder->getContext()));
-  state.operands.push_back(value);
-}
-
-mlir::LogicalResult TransposeOp::verify() {
-  if (failed(verifyToySingleOperand(this)))
-    return mlir::failure();
-  return mlir::success();
-}
-
-void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                      mlir::Value *value, ToyArrayType reshapedType) {
-  state.types.push_back(reshapedType);
-  state.operands.push_back(value);
-}
-
-mlir::LogicalResult ReshapeOp::verify() {
-  if (failed(verifyToySingleOperand(this)))
-    return mlir::failure();
-  auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
-  if (!retTy)
-    return emitOpError("toy.reshape is expected to produce a Toy array");
-  if (retTy.isGeneric())
-    return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
-                       "got a generic one.");
-  return mlir::success();
-}
-
-void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                  mlir::Value *lhs, mlir::Value *rhs) {
-  state.types.push_back(ToyArrayType::get(builder->getContext()));
-  state.operands.push_back(lhs);
-  state.operands.push_back(rhs);
-}
-
-mlir::LogicalResult AddOp::verify() {
-  if (failed(verifyToyBinOperands(this)))
-    return mlir::failure();
-  return mlir::success();
-}
-
-void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
-                  mlir::Value *lhs, mlir::Value *rhs) {
-  state.types.push_back(ToyArrayType::get(builder->getContext()));
-  state.operands.push_back(lhs);
-  state.operands.push_back(rhs);
-}
-
-mlir::LogicalResult MulOp::verify() {
-  if (failed(verifyToyBinOperands(this)))
-    return mlir::failure();
-  return mlir::success();
-}
-
-} // namespace toy
diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp
index 1c084e0918bf..6f75269b9be9 100644
--- a/mlir/examples/toy/Ch4/toyc.cpp
+++ b/mlir/examples/toy/Ch4/toyc.cpp
@@ -80,54 +80,63 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
   return parser.ParseModule();
 }

-mlir::LogicalResult optimize(mlir::ModuleOp module) {
-  mlir::PassManager pm(module.getContext());
-  pm.addPass(mlir::createCanonicalizerPass());
-  pm.addPass(createShapeInferencePass());
-  pm.addPass(mlir::createCanonicalizerPass());
-  // Apply any generic pass manager command line options.
-  applyPassManagerCLOptions(pm);
+int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
+  // Handle '.toy' input to the compiler.
+  if (inputType != InputType::MLIR &&
+      !llvm::StringRef(inputFilename).endswith(".mlir")) {
+    auto moduleAST = parseInputFile(inputFilename);
+    module = mlirGen(context, *moduleAST);
+    return !module ? 1 : 0;
+  }

-  return pm.run(module);
+  // Otherwise, the input is '.mlir'.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
+      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
+  if (std::error_code EC = fileOrErr.getError()) {
+    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
+    return -1;
+  }
+
+  // Parse the input mlir.
+  llvm::SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
+  module = mlir::parseSourceFile(sourceMgr, &context);
+  if (!module) {
+    llvm::errs() << "Error: can't load file " << inputFilename << "\n";
+    return 3;
+  }
+  return 0;
 }

 int dumpMLIR() {
-  // Register our Dialect with MLIR
-  mlir::registerDialect<ToyDialect>();
+  // Register our Dialect with MLIR.
+  mlir::registerDialect<mlir::toy::ToyDialect>();

   mlir::MLIRContext context;
   mlir::OwningModuleRef module;
-  if (inputType == InputType::MLIR ||
-      llvm::StringRef(inputFilename).endswith(".mlir")) {
-    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
-        llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
-    if (std::error_code EC = fileOrErr.getError()) {
-      llvm::errs() << "Could not open input file: " << EC.message() << "\n";
-      return -1;
-    }
-    llvm::SourceMgr sourceMgr;
-    sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
-    module = mlir::parseSourceFile(sourceMgr, &context);
-    if (!module) {
-      llvm::errs() << "Error can't load file " << inputFilename << "\n";
-      return 3;
-    }
-    if (failed(mlir::verify(*module))) {
-      llvm::errs() << "Error verifying MLIR module\n";
-      return 4;
-    }
-  } else {
-    auto moduleAST = parseInputFile(inputFilename);
-    module = mlirGen(context, *moduleAST);
-  }
-  if (!module)
-    return 1;
+  if (int error = loadMLIR(context, module))
+    return error;

   if (EnableOpt) {
-    if (failed(optimize(*module))) {
-      llvm::errs() << "Module optimization failed\n";
-      return 7;
-    }
+    mlir::PassManager pm(&context);
+    // Apply any generic pass manager command line options and run the pipeline.
+    applyPassManagerCLOptions(pm);
+
+    // Add a run of the canonicalizer to optimize the mlir module.
+    pm.addPass(mlir::createCanonicalizerPass());
+
+    // Inline all functions into main and then delete them.
+    pm.addPass(mlir::createInlinerPass());
+    pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
+
+    // Now that there is only one function, we can infer the shapes of each of
+    // the operations.
+    pm.addPass(mlir::toy::createShapeInferencePass());
+
+    if (mlir::failed(pm.run(*module)))
+      return 4;
   }

   module->dump();
   return 0;
 }
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-4.md b/mlir/g3doc/Tutorials/Toy/Ch-4.md
index c064942c2ea2..363cd3da39fa 100644
--- a/mlir/g3doc/Tutorials/Toy/Ch-4.md
+++ b/mlir/g3doc/Tutorials/Toy/Ch-4.md
@@ -1,242 +1,118 @@
-# Chapter 4: High-level Language-Specific Analysis and Transformation
+# Chapter 4: Using Interfaces

-Creating a dialect that closely represents the semantics of an input language
-enables analyses and transformations in MLIR that are generally performed on the
-language AST. For example, `clang` has a fairly
-[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html)
-for performing template instantiation in C++.
+[Interfaces](../../Interfaces.md) provide a generic method for applying
+transformations across dialects. We first describe how to leverage an existing
+MLIR interface, and then walk through writing your own interface.

-Another aspect is optimization. While some previous language specific
-optimizations have been implemented in LLVM (like the
-[ARC optimizer](http://llvm.org/doxygen/ObjCARCOpts_8cpp_source.html#l00468)),
-it has been at the cost of relying on either adding enough concepts in LLVM, to
-be able to embed the high-level semantics of the input, or using fragile
-"best-effort" metadata to decorate the IR with the information needed for these
-custom optimizations.
+## Function Inlining

-We show in this chapter how to leverage the Toy Dialect and its high-level
-semantics to perform transformations that would be difficult in LLVM: first a
-simple combine of two redundant operations, and second a full interprocedural
-shape inference with function specialization.
+In order to apply function inlining in the Toy dialect, we override the
+DialectInlinerInterface in Toy, enable inlining, and add special handling for
+the return operation:

-# Basic Optimization: Eliminate Redundant Transpose
+```Toy(.cpp)
+//===----------------------------------------------------------------------===//
+// ToyInlinerInterface
+//===----------------------------------------------------------------------===//

-Let's start with a simple pattern and try to eliminate a sequence of two
-transpose that cancel out: `transpose(transpose(X)) -> X`. Here is the
-corresponding Toy example:
+/// This class defines the interface for handling inlining with Toy
+/// operations.
+struct ToyInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;

-```Toy(.toy)
-def transpose_transpose(x) {
-  return transpose(transpose(x));
-}
-```
+  //===--------------------------------------------------------------------===//
+  // Analysis Hooks
+  //===--------------------------------------------------------------------===//

-Which corresponds to the following IR:
-
-```MLIR(.mlir)
-func @transpose_transpose(%arg0: !toy<"array">)
-  attributes {toy.generic: true} {
-  %0 = "toy.transpose"(%arg0) : (!toy<"array">) -> !toy<"array">
-  %1 = "toy.transpose"(%0) : (!toy<"array">) -> !toy<"array">
-  "toy.return"(%1) : (!toy<"array">) -> ()
-}
-```
-
-This is a good example of a transformation that is trivial to match on the Toy
-IR but that would be quite hard for LLVM to figure. For example today clang
-can't optimize away the temporary array and the computation with the naive
-transpose expressed with these loops:
-
-```c++
-#define N 100
-#define M 100
-
-void sink(void *);
-void double_transpose(int A[N][M]) {
-  int B[M][N];
-  for(int i = 0; i < N; ++i) {
-    for(int j = 0; j < M; ++j) {
-      B[j][i] = A[i][j];
-    }
+  /// All operations within toy can be inlined.
+  bool isLegalToInline(Operation *, Region *,
+                       BlockAndValueMapping &) const final {
+    return true;
   }
-  for(int i = 0; i < N; ++i) {
-    for(int j = 0; j < M; ++j) {
-      A[i][j] = B[j][i];
-    }
-  }
-  sink(A);
-}
-```

-For simple rewrite involving matching a tree-like pattern in the IR and
-replacing it with a different set of operations, we can plug into the MLIR
-`Canonicalizer` pass by implementing a `RewritePattern`:
+  //===--------------------------------------------------------------------===//
+  // Transformation Hooks
+  //===--------------------------------------------------------------------===//

-```c++
-/// Fold transpose(transpose(x)) -> x
-struct SimplifyRedundantTranspose : public mlir::RewritePattern {
-  /// We register this pattern to match every toy.transpose in the IR.
-  /// The "benefit" is used by the framework to order the patterns and process
-  /// them in order of profitability.
-  SimplifyRedundantTranspose(mlir::MLIRContext *context)
-      : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1, context) {}
+  /// Handle the given inlined terminator (toy.return) by replacing it with a
+  /// new operation as necessary.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value *> valuesToRepl) const final {
+    // Only "toy.return" needs to be handled here.
+    auto returnOp = cast<ReturnOp>(op);

-  /// This method is attempting to match a pattern and rewrite it. The rewriter
-  /// argument is the orchestrator of the sequence of rewrites. It is expected
-  /// to interact with it to perform any changes to the IR from here.
-  mlir::PatternMatchResult matchAndRewrite(
-      mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
-    // We can directly cast the current operation as this will only get invoked
-    // on TransposeOp.
-    TransposeOp transpose = op->cast<TransposeOp>();
-    // look through the input to the current transpose
-    mlir::Value *transposeInput = transpose.getOperand();
-    // If the input is defined by another Transpose, bingo!
-    if (!matchPattern(transposeInput, mlir::m_Op<TransposeOp>()))
-      return matchFailure();
-
-    auto transposeInputOp =
-        transposeInput->getDefiningOp()->cast<TransposeOp>();
-    // Use the rewriter to perform the replacement
-    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
-    return matchSuccess();
+    // Replace the values directly with the return operands.
+    assert(returnOp.getNumOperands() == valuesToRepl.size());
+    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
+      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
   }
 };
 ```
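The patch wires this interface up in `Dialect.cpp` (not shown in this excerpt).
A minimal sketch of what that registration typically looks like, reusing the
tablegen'd operation list from the CMake changes; treat the exact body as
illustrative rather than a quote of the patch:

```Toy(.cpp)
// Illustrative sketch: attach the inliner interface (and the generated op
// list) when the dialect is constructed.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  // Without this, mlir::createInlinerPass() has no way to query the legality
  // hooks defined by ToyInlinerInterface above.
  addInterfaces<ToyInlinerInterface>();
}
```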
-Let's see how to improve our `TransposeOp` by extending it with a new static
-method:
+Next, we call into the interface by adding an inliner pass to the pass manager
+for toy:

-```c++
-  /// This hook returns any canonicalization pattern rewrites that the operation
-  /// supports, for use by the canonicalization pass.
-  static void getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
-                                          mlir::MLIRContext *context) {
-    results.push_back(std::make_unique<SimplifyRedundantTranspose>(context));
+```Toy(.cpp)
+  pm.addPass(mlir::createInlinerPass());
+```
+
+** Insert example here **
+
+## Shape Inference
+
+The Toy language allows for implicit shapes and hence requires shape inference.
+We implement shape inference as a generic
+[Operation Interface](../../Interfaces.md#operation-interfaces).
+
+1.  We first create the ShapeInferenceOpInterface by specializing the
+    OpInterface class using [ODS](../../OpDefinitions.md#operation-interfaces).
+    This class defines interface methods that Toy operations must override for
+    shape inference.
+
+```Toy(.cpp)
+def ShapeInferenceOpInterface : OpInterface<"ShapeInferenceOpInterface"> {
+  let methods = [
+    InterfaceMethod<
+      "bool", "returnsGenericArray", (ins), [{
+        if (op.getNumResults() == 1)
+          return !op.getResult(0)->getType().isa<RankedTensorType>();
+        return false;
+      }]>,
+    InterfaceMethod<"void", "inferShapes", (ins), [{}]>
+  ];
+}
+```
+
+2.  Next, we override the inferShapes() method within Toy operations. As an
+    example, for the transpose op, the result shape is inferred by swapping the
+    dimensions of the input tensor.
+
+```Toy(.cpp)
+  void inferShapes() {
+    SmallVector<int64_t, 2> dims;
+    auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
+    dims.insert(dims.end(), arrayTy.getShape().begin(),
+                arrayTy.getShape().end());
+    if (dims.size() == 2)
+      std::swap(dims[0], dims[1]);
+    getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
   }
 ```

-The implementation of this rewriter is in `ToyCombine.cpp`. We also need to
-update our main file, `toyc.cpp`, to add an optimization pipeline. In MLIR, the
-optimizations are ran through a `PassManager` in a similar way to LLVM:
+3.  We then create a generic ShapeInference Function pass that uses operation
+    casting to access the inferShapes() method. This is an intraprocedural
+    shape inference pass that executes after function inlining and iterates
+    over operations in a worklist, calling inferShapes() for each operation
+    with unknown result shapes; a condensed sketch of that loop follows.
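Condensed from the `ShapeInferencePass.cpp` changes earlier in this patch (the
names match that file; the body is abridged here, not a verbatim quote):

```Toy(.cpp)
  void runOnFunction() override {
    auto f = getFunction();

    // Collect every operation whose result type is still generic (unranked).
    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
    f.walk([&](mlir::Operation *op) {
      if (returnsGenericArray(op))
        opWorklist.insert(op);
    });

    // Propagate shapes to a fixed point: pick an op whose operands are all
    // ranked, and ask it to infer its result shape through the interface.
    while (!opWorklist.empty()) {
      auto nextop = llvm::find_if(opWorklist, [](Operation *op) {
        return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
          return ty.isa<RankedTensorType>();
        });
      });
      if (nextop == opWorklist.end())
        break; // No further progress is possible.

      Operation *op = *nextop;
      opWorklist.erase(op);
      dyn_cast<ShapeInferenceOpInterface>(op).inferShapes();
    }

    // Any leftover operation could not be inferred.
    if (!opWorklist.empty())
      signalPassFailure();
  }
```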
-```c++
-mlir::PassManager pm(ctx);
-pm.addPass(mlir::createCanonicalizerPass());
-pm.run(&module);
+4.  Finally, we call into the shape inference pass by adding it to the pass
+    manager for toy:
+
+```Toy(.cpp)
+  pm.addPass(mlir::toy::createShapeInferencePass());
 ```

-Finally, we can try to run `toyc test/transpose_transpose.toy -emit=mlir -opt`
-and observe our pattern in action:
-
-```MLIR(.mlir)
-func @transpose_transpose(%arg0: !toy<"array">)
-  attributes {toy.generic: true} {
-  %0 = "toy.transpose"(%arg0) : (!toy<"array">) -> !toy<"array">
-  "toy.return"(%arg0) : (!toy<"array">) -> ()
-}
-```
-
-As expected we now directly return the function argument, bypassing any
-transpose operation. However one of the transpose hasn't been eliminated. That
-is not ideal! What happened is that our pattern replaced the last transform with
-the function input and left behind the now dead transpose input. The
-Canonicalizer knows to cleanup dead operations, however MLIR conservatively
-assumes that operations may have side-effects. We can fix it by adding a new
-trait, `HasNoSideEffect`, to our `TransposeOp`:
-
-```c++
-class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
-                                    mlir::OpTrait::OneResult,
-                                    mlir::OpTrait::HasNoSideEffect> {
-```
-
-Let's retry now `toyc test/transpose_transpose.toy -emit=mlir -opt`:
-
-```MLIR(.mlir)
-func @transpose_transpose(%arg0: !toy<"array">)
-  attributes {toy.generic: true} {
-  "toy.return"(%arg0) : (!toy<"array">) -> ()
-}
-```
-
-Perfect! No `transpose` operation is left, the code is optimal.
-
-The code in `mlir/ToyCombine.cpp` implements a few more patterns that eliminate
-trivial reshapes, or fold them into constants.
-
-# Shape Inference and Generic Function Specialization
-
-Our IR operates on generic arrays, we don't know the shape of the arrays other
-than during initialization of constants. However we can propagate the shapes
-through the computation until they are all known. The issue is how to handle
-calls to user-defined generic functions: every call site could deduce different
-shapes. One possibility would be to perform symbolic inference based on the
-argument types, but this would be hard to generalize if we were to introduce
-more control flow in the language. Instead we will proceed by function
-specialization: for every call site with new argument shapes we duplicate the
-function and specialize it. This is akin to C++ template instantiation:
-
-```
-template <int M1, int N1, int M2, int N2>
-auto multiply_add(array<M1, N1> a, array<M2, N2> b) {
-  auto prod = mul(a, b);
-  auto sum = add(prod, a);
-  return sum;
-}
-```
-
-Every new call to `multiply_add` would instantiate the template and emit code
-for the specific shape and deduce the return type. Clang implements this
-transformation on its AST, but we will implement it in an MLIR pass here.
-
-The ShapeInferencePass is a `ModulePass`: it will run on the Module as a whole.
-MLIR also supports `FunctionPass`es which are restricted to modify a single
-function at a time. This pass couldn't be a function pass due the nature of its
-interprocedural transformations.
-
-Implementing such a pass is done by creating a class inheriting from
-`mlir::ModulePass` and overriding the `runOnModule()` method:
-
-```
-class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
-
-  void runOnModule() override {
-    auto &module = getModule();
-    ...
-```
-
-The algorithm has two levels, first intra-procedurally:
-
-1. Build a worklist containing all the operations that are returning a generic
-   Toy array: these are the operations that need shape inference.
-2. Iterate on the worklist:
-   - find an operation to process: the next ready operation in the worklist
-     has all of its arguments non-generic,
-   - if no operation is found, break out of the loop,
-   - remove the operation from the worklist,
-   - infer the shape of its output from the arguments type.
-3. If the worklist is empty, the algorithm succeeded and we infer the return
-   type for the function from the return operation.
-
-There is a twist though: when a call to a generic function is encountered, shape
-inference requires the return type of the callee to be inferred first. At this
-point we need to specialize the callee by cloning it. Here is the
-inter-procedural flow that wraps the intra-procedural inference:
-
-1. Keep a worklist of function to process. Start with function "main".
-2. While the worklist isn't empty:
-   - Take the last inserted function in the worklist.
-   - Run the intra-procedural shape inference on this function.
-   - If the intra-procedural shape inference can't complete, it returns a
-     FuncOp that needs to be inferred first. In this case, queue this new
-     function and continue. Otherwise the inference succeeded and we can pop
-     from the queue.
-
-The full code is in `mlir/ShapeInferencePass.cpp`.
-
-# Future Work: Optimizing Buffer Allocation?
-
-Toy is value-based. Naively this is a lot of allocation, what if we want to
-statically optimize placement? What is the right abstraction level to perform
-buffer assignment?
+** Insert example here **
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 1b711dc0465c..67ec4a3d3840 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -141,6 +141,9 @@ std::unique_ptr<OpPassBase<FuncOp>> createStripDebugInfoPass();
 /// Creates a pass which tests loop fusion utilities.
 std::unique_ptr<OpPassBase<FuncOp>> createTestLoopFusionPass();

+/// Creates a pass which inlines calls and callable operations as defined by the
+/// CallGraph.
+std::unique_ptr<Pass> createInlinerPass();
 } // end namespace mlir

 #endif // MLIR_TRANSFORMS_PASSES_H
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 9bb7b0d6e4d2..dbb5381ed708 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -291,4 +291,8 @@ struct InlinerPass : public OperationPass<InlinerPass> {
 };
 } // end anonymous namespace

+std::unique_ptr<Pass> mlir::createInlinerPass() {
+  return std::make_unique<InlinerPass>();
+}
+
 static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
diff --git a/mlir/test/Examples/Toy/Ch4/ast.toy b/mlir/test/Examples/Toy/Ch4/ast.toy
index a0897c0ba0cb..9576c9c5ced0 100644
--- a/mlir/test/Examples/Toy/Ch4/ast.toy
+++ b/mlir/test/Examples/Toy/Ch4/ast.toy
@@ -10,7 +10,7 @@ def main() {
   # Define a variable `a` with shape <2, 3>, initialized with the literal value.
   # The shape is inferred from the supplied literal.
   var a = [[1, 2, 3], [4, 5, 6]];
-  # b is identical to a, the literal array is implicitely reshaped: defining new
+  # b is identical to a, the literal array is implicitly reshaped: defining new
   # variables is the way to reshape arrays (element count must match).
var b<2, 3> = [1, 2, 3, 4, 5, 6]; # This call will specialize `multiply_transpose` with <2, 3> for both diff --git a/mlir/test/Examples/Toy/Ch4/codegen.toy b/mlir/test/Examples/Toy/Ch4/codegen.toy index ff47f9952465..722ff4a25875 100644 --- a/mlir/test/Examples/Toy/Ch4/codegen.toy +++ b/mlir/test/Examples/Toy/Ch4/codegen.toy @@ -13,20 +13,19 @@ def main() { print(d); } -# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array) -# CHECK-NEXT: attributes {toy.generic = true} { -# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array -# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array -# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> () -# CHECK-NEXT: } - -# CHECK-LABEL: func @main() { -# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3> -# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3> -# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6> -# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3> -# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array -# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array -# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> () -# CHECK-NEXT: "toy.return"() : () -> () +# CHECK-LABEL: func @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) +# CHECK-NEXT: attributes {toy.generic} { +# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> () +# CHECK-LABEL: func @main() { +# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> () +# CHECK-NEXT: "toy.return"() : () -> () diff --git a/mlir/test/Examples/Toy/Ch4/invalid.mlir b/mlir/test/Examples/Toy/Ch4/invalid.mlir index d15386640e85..8d1bb27344f2 100644 --- a/mlir/test/Examples/Toy/Ch4/invalid.mlir +++ b/mlir/test/Examples/Toy/Ch4/invalid.mlir @@ -1,11 +1,9 @@ // RUN: not toyc-ch4 %s -emit=mlir 2>&1 - -// This IR is not "valid": +// The following IR is not "valid": // - toy.print should not return a value. // - toy.print should take an argument. 
 // - There should be a block terminator.
-// This all round-trip since this is opaque for MLIR.
 func @main() {
-  %0 = "toy.print"() : () -> !toy.array<2, 3>
+  %0 = "toy.print"() : () -> tensor<2x3xf64>
 }
diff --git a/mlir/test/Examples/Toy/Ch4/scalar.toy b/mlir/test/Examples/Toy/Ch4/scalar.toy
index 6231fc992c3b..032b3b02b9d9 100644
--- a/mlir/test/Examples/Toy/Ch4/scalar.toy
+++ b/mlir/test/Examples/Toy/Ch4/scalar.toy
@@ -6,9 +6,9 @@ def main() {
 }
 
 # CHECK-LABEL: func @main() {
-# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1>
-# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
-# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> ()
+# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
+# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
+# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
 # CHECK-NEXT: "toy.return"() : () -> ()
 # CHECK-NEXT: }
diff --git a/mlir/test/Examples/Toy/Ch4/transpose_transpose.toy b/mlir/test/Examples/Toy/Ch4/transpose_transpose.toy
deleted file mode 100644
index 31399eee53f5..000000000000
--- a/mlir/test/Examples/Toy/Ch4/transpose_transpose.toy
+++ /dev/null
@@ -1,19 +0,0 @@
-# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s
-# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
-
-def transpose_transpose(x) {
-  return transpose(transpose(x));
-}
-
-def main() {
-  print(transpose_transpose([[1, 2], [3, 4]]));
-}
-
-#CHECK-LABEL: func @transpose_transpose
-#CHECK: transpose
-#CHECK-LABEL: main
-
-
-#OPT-LABEL: func @transpose_transpose
-#OPT-NOT: transpose
-
diff --git a/mlir/test/Examples/Toy/Ch4/trivialReshape.toy b/mlir/test/Examples/Toy/Ch4/trivialReshape.toy
deleted file mode 100644
index c7a805d89ef4..000000000000
--- a/mlir/test/Examples/Toy/Ch4/trivialReshape.toy
+++ /dev/null
@@ -1,24 +0,0 @@
-# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s
-# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT
-
-# We expect no reshape in this function with optimizations enabled
-def foo(a) {
-  var b<2,1> = a;
-  var c<2,1> = b;
-  print(c);
-}
-
-def main() {
-  var a<2, 1> = [1, 2];
-  foo(a);
-}
-
-# without optimizations, match the reshape
-#CHECK-LABEL: func @foo
-#CHECK: reshape
-#CHECK-LABEL: main
-
-# with optimizations, ensure no reshape
-#OPT-LABEL: main
-#OPT-LABEL: func @foo_2x1
-#OPT-NOT: reshape
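
Note on how the pieces in this patch compose: the new `mlir::createInlinerPass()` entry point exposed in `Passes.h` and the `mlir::createShapeInferencePass()` call quoted in the rewritten chapter text are meant to run back-to-back in the Ch4 driver. The sketch below shows one plausible wiring; it is not code from this commit. The `toy/Passes.h` include and the `runToyOptPipeline` helper are hypothetical, and `context`/`module` are assumed to come from the surrounding tutorial driver.

```c++
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h" // declares mlir::createInlinerPass()
#include "toy/Passes.h" // hypothetical: assumed to declare createShapeInferencePass()

// Illustrative helper (not part of this patch): run the Ch4 pipeline.
void runToyOptPipeline(mlir::MLIRContext &context,
                       mlir::OwningModuleRef &module) {
  mlir::PassManager pm(&context);

  // Inline calls first, driven by the call graph, so that shape inference
  // only needs to propagate shapes within a single function afterwards;
  // finish with the canonicalizer to clean up now-dead operations.
  pm.addPass(mlir::createInlinerPass());
  pm.addPass(mlir::createShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());

  if (mlir::failed(pm.run(*module)))
    llvm::errs() << "Toy optimization pipeline failed\n";
}
```

Scheduling the inliner ahead of shape inference appears to be what lets this patch retire the inter-procedural specialization worklist described in the deleted chapter text above: after inlining, inference can stay intra-procedural.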