From 1ba9bb05078aee74420bc64394a20c782f13a125 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 16 Oct 2019 17:33:34 -0700 Subject: [PATCH] Add Ch.5 of the toy tutorial. This chapter adds a partial lowering of toy operations, all but PrintOp, to a combination of the Affine and Std dialects. This chapter focuses on introducing the conversion framework, the benefits of partial lowering, and how easily dialects may co-exist in the IR. PiperOrigin-RevId: 275150649 --- mlir/examples/toy/Ch3/mlir/Dialect.cpp | 9 +- mlir/examples/toy/Ch4/mlir/Dialect.cpp | 6 +- mlir/examples/toy/Ch4/mlir/MLIRGen.cpp | 9 +- mlir/examples/toy/Ch5/CMakeLists.txt | 40 +- mlir/examples/toy/Ch5/include/CMakeLists.txt | 1 + mlir/examples/toy/Ch5/include/toy/AST.h | 15 +- .../toy/Ch5/include/toy/CMakeLists.txt | 9 + mlir/examples/toy/Ch5/include/toy/Dialect.h | 364 +----------- mlir/examples/toy/Ch5/include/toy/Lexer.h | 2 +- mlir/examples/toy/Ch5/include/toy/Lowering.h | 45 -- mlir/examples/toy/Ch5/include/toy/Ops.td | 272 +++++++++ mlir/examples/toy/Ch5/include/toy/Passes.h | 12 +- .../Ch5/include/toy/ShapeInferenceInterface.h | 37 ++ .../include/toy/ShapeInferenceInterface.td | 38 ++ .../Ch5/mlir/DeadFunctionEliminationPass.cpp | 68 +++ mlir/examples/toy/Ch5/mlir/Dialect.cpp | 256 ++++++++ mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 148 ----- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 470 --------------- .../toy/Ch5/mlir/LowerToAffineLoops.cpp | 318 ++++++++++ mlir/examples/toy/Ch5/mlir/MLIRGen.cpp | 414 +++++++------ .../toy/Ch5/mlir/ShapeInferencePass.cpp | 357 ++---------- mlir/examples/toy/Ch5/mlir/ToyCombine.cpp | 161 +---- mlir/examples/toy/Ch5/mlir/ToyCombine.td | 73 +++ mlir/examples/toy/Ch5/mlir/ToyDialect.cpp | 403 ------------- mlir/examples/toy/Ch5/toyc.cpp | 258 +++----- mlir/g3doc/Tutorials/Toy/Ch-5.md | 549 ++++++++++-------- mlir/test/Examples/Toy/Ch4/ast.toy | 57 +- .../Examples/Toy/Ch5/affine-lowering.mlir | 65 +++ mlir/test/Examples/Toy/Ch5/ast.toy | 59 +- mlir/test/Examples/Toy/Ch5/codegen.toy | 31 +- mlir/test/Examples/Toy/Ch5/invalid.mlir | 6 +- mlir/test/Examples/Toy/Ch5/lowering.toy | 16 - mlir/test/Examples/Toy/Ch5/scalar.toy | 6 +- .../Examples/Toy/Ch5/shape_inference.mlir | 30 + .../Examples/Toy/Ch5/transpose_transpose.toy | 19 - mlir/test/Examples/Toy/Ch5/trivialReshape.toy | 24 - 36 files changed, 1961 insertions(+), 2686 deletions(-) create mode 100644 mlir/examples/toy/Ch5/include/CMakeLists.txt create mode 100644 mlir/examples/toy/Ch5/include/toy/CMakeLists.txt delete mode 100644 mlir/examples/toy/Ch5/include/toy/Lowering.h create mode 100644 mlir/examples/toy/Ch5/include/toy/Ops.td create mode 100644 mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h create mode 100644 mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td create mode 100644 mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp create mode 100644 mlir/examples/toy/Ch5/mlir/Dialect.cpp delete mode 100644 mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp delete mode 100644 mlir/examples/toy/Ch5/mlir/LateLowering.cpp create mode 100644 mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp create mode 100644 mlir/examples/toy/Ch5/mlir/ToyCombine.td delete mode 100644 mlir/examples/toy/Ch5/mlir/ToyDialect.cpp create mode 100644 mlir/test/Examples/Toy/Ch5/affine-lowering.mlir delete mode 100644 mlir/test/Examples/Toy/Ch5/lowering.toy create mode 100644 mlir/test/Examples/Toy/Ch5/shape_inference.mlir delete mode 100644 mlir/test/Examples/Toy/Ch5/transpose_transpose.toy delete mode 100644 mlir/test/Examples/Toy/Ch5/trivialReshape.toy diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp index 375533b880c8..2688fe6b12a5 100644 --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -55,7 +55,8 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state, ConstantOp::build(builder, state, dataType, dataAttribute); } -/// Verifier for constant operation. +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. static mlir::LogicalResult verify(ConstantOp op) { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. @@ -63,6 +64,8 @@ static mlir::LogicalResult verify(ConstantOp op) { if (!resultType) return success(); + // Check that the rank of the attribute type matches the rank of the constant + // result type. auto attrType = op.value().getType().cast(); if (attrType.getRank() != resultType.getRank()) { return op.emitOpError( @@ -70,7 +73,9 @@ static mlir::LogicalResult verify(ConstantOp op) { "attribute: ") << attrType.getRank() << " != " << resultType.getRank(); } - for (int dim = 0; dim < attrType.getRank(); ++dim) { + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { if (attrType.getShape()[dim] != resultType.getShape()[dim]) { return op.emitOpError( "return type shape mismatches its attribute at dimension ") diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp index d33dbd8be0aa..2ac86cf99846 100644 --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -118,6 +118,8 @@ static mlir::LogicalResult verify(ConstantOp op) { if (!resultType) return success(); + // Check that the rank of the attribute type matches the rank of the constant + // result type. auto attrType = op.value().getType().cast(); if (attrType.getRank() != resultType.getRank()) { return op.emitOpError( @@ -125,7 +127,9 @@ static mlir::LogicalResult verify(ConstantOp op) { "attribute: ") << attrType.getRank() << " != " << resultType.getRank(); } - for (int dim = 0; dim < attrType.getRank(); ++dim) { + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { if (attrType.getShape()[dim] != resultType.getShape()[dim]) { return op.emitOpError( "return type shape mismatches its attribute at dimension ") diff --git a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp index ace52aff2bf0..5f12d0a8798a 100644 --- a/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch4/mlir/MLIRGen.cpp @@ -79,7 +79,7 @@ public: // the structural properties of the IR and invoke any specific verifiers we // have on the Toy operations. if (failed(mlir::verify(theModule))) { - theModule.emitError("Module verification error"); + theModule.emitError("module verification error"); return nullptr; } @@ -229,7 +229,7 @@ private: if (auto *variable = symbolTable.lookup(expr.getName())) return variable; - emitError(loc(expr.loc()), "Error: unknown variable '") + emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; return nullptr; } @@ -289,7 +289,8 @@ private: auto dataAttribute = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); - // Build the MLIR op `toy.constant`. + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. return builder.create(loc(lit.loc()), type, dataAttribute); } @@ -389,7 +390,7 @@ private: auto init = vardecl.getInitVal(); if (!init) { emitError(loc(vardecl.loc()), - "Missing initializer in variable declaration"); + "missing initializer in variable declaration"); return nullptr; } diff --git a/mlir/examples/toy/Ch5/CMakeLists.txt b/mlir/examples/toy/Ch5/CMakeLists.txt index 1b20523ae4c2..df5239589de2 100644 --- a/mlir/examples/toy/Ch5/CMakeLists.txt +++ b/mlir/examples/toy/Ch5/CMakeLists.txt @@ -1,40 +1,42 @@ +add_subdirectory(include) + set(LLVM_LINK_COMPONENTS Support ) +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include") +add_public_tablegen_target(ToyCh5CombineIncGen) + add_toy_chapter(toyc-ch5 toyc.cpp parser/AST.cpp - mlir/EarlyLowering.cpp - mlir/LateLowering.cpp mlir/MLIRGen.cpp + mlir/Dialect.cpp + mlir/DeadFunctionEliminationPass.cpp + mlir/LowerToAffineLoops.cpp mlir/ShapeInferencePass.cpp - mlir/ToyDialect.cpp mlir/ToyCombine.cpp ) + +add_dependencies(toyc-ch5 ToyCh5ShapeInferenceInterfaceIncGen) +add_dependencies(toyc-ch5 ToyCh5OpsIncGen) +add_dependencies(toyc-ch5 ToyCh5CombineIncGen) +add_dependencies(toyc-ch5 MLIRCallOpInterfacesIncGen) include_directories(include/) -include_directories(../../Linalg/Linalg1/include/) -include_directories(../../Linalg/Linalg2/include/) -include_directories(../../Linalg/Linalg3/include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) target_link_libraries(toyc-ch5 PRIVATE - Linalg3DialectConstruction - Linalg3 - Linalg2 - Linalg1 + MLIRAffineOps MLIRAnalysis - MLIREDSC - MLIRExecutionEngine MLIRIR - MLIRLLVMIR MLIRParser MLIRPass - MLIRTargetLLVMIR - MLIRTransforms - MLIRSupport -) + MLIRStandardOps + MLIRTransforms) + whole_archive_link(toyc-ch5 MLIRAffineOps MLIRStandardOps -) - + ) diff --git a/mlir/examples/toy/Ch5/include/CMakeLists.txt b/mlir/examples/toy/Ch5/include/CMakeLists.txt new file mode 100644 index 000000000000..37c89d0bae96 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(toy) diff --git a/mlir/examples/toy/Ch5/include/toy/AST.h b/mlir/examples/toy/Ch5/include/toy/AST.h index 456a32309c40..2ad3392c11ac 100644 --- a/mlir/examples/toy/Ch5/include/toy/AST.h +++ b/mlir/examples/toy/Ch5/include/toy/AST.h @@ -33,10 +33,9 @@ namespace toy { -/// A variable +/// A variable type with shape information. struct VarType { - enum { TY_FLOAT, TY_INT } elt_ty; - std::vector shape; + std::vector shape; }; /// Base class for all expression nodes. @@ -50,9 +49,7 @@ public: Expr_Var, Expr_BinOp, Expr_Call, - Expr_Print, // builtin - Expr_If, - Expr_For, + Expr_Print, }; ExprAST(ExprASTKind kind, Location location) @@ -85,7 +82,7 @@ public: static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } }; -/// +/// Expression class for a literal value. class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; @@ -116,7 +113,7 @@ public: static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } }; -/// +/// Expression class for defining a variable. class VarDeclExprAST : public ExprAST { std::string name; VarType type; @@ -136,7 +133,7 @@ public: static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } }; -/// +/// Expression class for a return operator. class ReturnExprAST : public ExprAST { llvm::Optional> expr; diff --git a/mlir/examples/toy/Ch5/include/toy/CMakeLists.txt b/mlir/examples/toy/Ch5/include/toy/CMakeLists.txt new file mode 100644 index 000000000000..aaa932896d0f --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS Ops.td) +mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..") +add_public_tablegen_target(ToyCh5OpsIncGen) + +set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) +mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ToyCh5ShapeInferenceInterfaceIncGen) diff --git a/mlir/examples/toy/Ch5/include/toy/Dialect.h b/mlir/examples/toy/Ch5/include/toy/Dialect.h index ee85dbec9bcd..556ae972b84e 100644 --- a/mlir/examples/toy/Ch5/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch5/include/toy/Dialect.h @@ -16,7 +16,7 @@ // ============================================================================= // // This file implements the IR Dialect for the Toy language. -// See g3doc/Tutorials/Toy/Ch-3.md for more information. +// See g3doc/Tutorials/Toy/Ch-2.md for more information. // //===----------------------------------------------------------------------===// @@ -25,369 +25,31 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/Types.h" +#include "toy/ShapeInferenceInterface.h" namespace mlir { -class Builder; -} - namespace toy { /// This is the definition of the Toy dialect. A dialect inherits from -/// mlir::Dialect and register custom operations and types (in its constructor). -/// It can also overriding general behavior of dialects exposed as virtual -/// method, for example regarding verification and parsing/printing. +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. class ToyDialect : public mlir::Dialect { public: explicit ToyDialect(mlir::MLIRContext *ctx); - /// Parse a type registered to this dialect. Overriding this method is - /// required for dialects that have custom types. - /// Technically this is only needed to be able to round-trip to textual IR. - mlir::Type parseType(llvm::StringRef tyData, - mlir::Location loc) const override; - - /// Print a type registered to this dialect. Overriding this method is - /// only required for dialects that have custom types. - /// Technically this is only needed to be able to round-trip to textual IR. - void printType(mlir::Type type, llvm::raw_ostream &os) const override; + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } }; -//////////////////////////////////////////////////////////////////////////////// -/////////////////////// Custom Types for the Dialect /////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -namespace detail { -struct ToyArrayTypeStorage; -} - -/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa. -enum ToyTypeKind { - // The enum starts at the range reserved for this dialect. - TOY_TYPE = mlir::Type::FIRST_TOY_TYPE, - TOY_ARRAY, -}; - -/// Type for Toy arrays. -/// In MLIR Types are reference to immutable and uniqued objects owned by the -/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued -/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and -/// provides the public facade API to interact with the type. -class ToyArrayType : public mlir::Type::TypeBase { -public: - using Base::Base; - - /// Returns the dimensions for this array, or and empty range for a generic - /// array. - llvm::ArrayRef getShape(); - - /// Predicate to test if this array is generic (shape haven't been inferred - /// yet). - bool isGeneric() { return getShape().empty(); } - - /// Return the rank of this array (0 if it is generic). - int getRank() { return getShape().size(); } - - /// Return the type of individual elements in the array. - mlir::Type getElementType(); - - /// Get a MemRef equivalent to this array type. - mlir::MemRefType toMemref(); - - /// Get the unique instance of this Type from the context. - /// A ToyArrayType is only defined by the shape of the array. - static ToyArrayType get(mlir::MLIRContext *context, - llvm::ArrayRef shape = {}); - - /// Support method to enable LLVM-style RTTI type casting. - static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; } -}; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////// Custom Operations for the Dialect ///////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Constant operation turns a literal into an SSA value. The data is attached -/// to the operation as an attribute. For example: -/// -/// %0 = "toy.constant"() -/// {value: dense, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>} -/// : () -> !toy<"array<2, 3>"> -/// -/// An operation inherits from `class Op` and specifies optional traits. Here we -/// indicate that `toy.constant` does not have any operands and returns a single -/// result. The traits provide some utilities methods for the operation, for -/// instance we will be able to use `getResult()`, but `getOperand()` won't be -/// available. -class ConstantOp : public mlir::Op { -public: - /// This is the name used by MLIR to match an operation to this class during - /// parsing. - static llvm::StringRef getOperationName() { return "toy.constant"; } - - /// The operation can have extra verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populates the `state` that MLIR uses to create operations. - /// The `toy.constant` operation does not have arguments but attaches a - /// constant array as an attribute and returns it as an SSA value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - llvm::ArrayRef shape, - mlir::DenseElementsAttr value); - - /// Similar to the one above, but takes a single float and returns a - /// !toy<"array<1>">. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::FloatAttr value); - - mlir::DenseElementsAttr getValue() { - return getAttr("value").cast(); - } - - /// Inherit constructor. - using Op::Op; -}; - -/// Generic calls represent calls to a user defined function that needs to -/// be specialized for the shape of its arguments. The callee name is attached -/// as a literal string as an attribute. The arguments list must match the -/// arguments expected by the callee. For example: -/// -/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"} -/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> -/// -/// This is only valid if a function named "my_func" exists and takes two -/// arguments. -class GenericCallOp - : public mlir::Op { -public: - /// MLIR will use this to register the operation with the parser/printer. - static llvm::StringRef getOperationName() { return "toy.generic_call"; } - - /// Operations can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to the builder to allow: - /// mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.generic_call` operation accepts a callee name and a list of - /// arguments for the call. - static void build(mlir::Builder *builder, mlir::OperationState &state, - llvm::StringRef callee, - llvm::ArrayRef arguments); - - /// Return the name of the callee. - llvm::StringRef getCalleeName(); - - /// Inherit constructor. - using Op::Op; -}; - -/// Return operations terminate blocks (and functions as well). They take a -/// single argument and the type must match the function return type. -class ReturnOp - : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.return"; } - - /// Operations can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.return` operation accepts an optional single array as an argument - /// and does not have any returned value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value = nullptr); - - /// Return true if there is a returned value. - bool hasOperand() { return 0 != getNumOperands(); } - - /// Helper to return the optional operand. Caller must check if the operand - /// is present before calling this. - mlir::Value *getOperand() { return getOperation()->getOperand(0); } - - /// Inherit constructor. - using Op::Op; -}; - -/// The print builtin takes a single array argument and does not return any. -class PrintOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.print"; } - - /// Operations can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.print` operation accepts a single array as argument and does - /// not have any returned value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value); - - /// Inherit constructor. - using Op::Op; -}; - -class TransposeOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.transpose"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.transpose` operation accepts a single array as argument and - /// returns the transposed array as its only result. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value); - - // Register our patterns for rewrite by the Canonicalization framework. - static void - getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, - mlir::MLIRContext *context); - - /// Inherit constructor. - using Op::Op; -}; - -/// Reshape operation is transforming its input array into a new array with the -/// same number of elements but different shapes. For example: -/// -/// %0 = "toy.transpose"(%arg1) : (!toy<"array<10>">) -> !toy<"array<5, 2>"> -/// -class ReshapeOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.reshape"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.reshape` operation accepts a single array as argument and - /// returns the array with the specified reshapedType as its only result. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value, ToyArrayType reshapedType); - - // Register our patterns for rewrite by the Canonicalization framework. - static void - getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, - mlir::MLIRContext *context); - - /// Inherit constructor. - using Op::Op; -}; - -/// Binary operation implementing a multiplication. For two-dimensional array -/// a matrix multiplication is implemented, while for one dimensional array a -/// dot product is performed. -class MulOp : public mlir::Op::Impl, - mlir::OpTrait::OneResult, - mlir::OpTrait::HasNoSideEffect> { -public: - static llvm::StringRef getOperationName() { return "toy.mul"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.mul` operation accepts two operands as argument and returns - /// a single value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs); - - /// Convenience accessor for LHS of the expression. - mlir::Value *getLHS() { return getOperand(0); } - - /// Convenience accessor for RHS of the expression. - mlir::Value *getRHS() { return getOperand(1); } - - /// Inherit constructor. - using Op::Op; -}; - -/// Element wise addition of two arrays. The shape must match. -class AddOp : public mlir::Op::Impl, - mlir::OpTrait::OneResult, - mlir::OpTrait::HasNoSideEffect> { -public: - static llvm::StringRef getOperationName() { return "toy.add"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.mul` operation accepts two operands as argument and returns - /// a single value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs); - - /// Convenience accessor for LHS of the expression. - mlir::Value *getLHS() { return getOperand(0); } - - /// Convenience accessor for RHS of the expression. - mlir::Value *getRHS() { return getOperand(1); } - - /// Inherit constructor. - using Op::Op; -}; - -/// AllocOp is a temporary operation for buffer allocation, created as part of -/// partial lowering. -class AllocOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.alloc"; } - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// `toy.alloc` does not have any argument and returns a toy array. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Type retType); - - /// Inherit constructor. - using Op::Op; -}; - -/// FIXME: should be in std? -class TypeCastOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.cast"; } - - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value, mlir::Type destTy); - - // Register our patterns for rewrite by the Canonicalization framework. - static void - getCanonicalizationPatterns(mlir::OwningRewritePatternList &results, - mlir::MLIRContext *context); - - /// Inherit constructor. - using Op::Op; -}; +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" } // end namespace toy +} // end namespace mlir #endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch5/include/toy/Lexer.h b/mlir/examples/toy/Ch5/include/toy/Lexer.h index d73adb9706b7..21f92614912e 100644 --- a/mlir/examples/toy/Ch5/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch5/include/toy/Lexer.h @@ -31,7 +31,7 @@ namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename + std::shared_ptr file; ///< filename. int line; ///< line number. int col; ///< column number. }; diff --git a/mlir/examples/toy/Ch5/include/toy/Lowering.h b/mlir/examples/toy/Ch5/include/toy/Lowering.h deleted file mode 100644 index 4788ea3fbebe..000000000000 --- a/mlir/examples/toy/Ch5/include/toy/Lowering.h +++ /dev/null @@ -1,45 +0,0 @@ -//===- Lowering.h - Lexer for the Toy language ----------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file exposes the interface to the lowering for Toy. It is divided in -// two parts: an *early lowering* that emits operations in the `Linalg` -// dialects for a subset of the Toy IR, and a *late lowering* that materializes -// buffers and converts all operations and type to the LLVM dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_EXAMPLES_TOY_LOWERING_H_ -#define MLIR_EXAMPLES_TOY_LOWERING_H_ - -#include - -namespace mlir { -class Pass; -class DialectConversion; -} // namespace mlir - -namespace toy { -/// Create a pass for lowering operations in the `Linalg` dialects, for a subset -/// of the Toy IR (matmul). -std::unique_ptr createEarlyLoweringPass(); - -/// Create a pass for the late lowering toward LLVM dialect. -std::unique_ptr createLateLoweringPass(); - -} // namespace toy - -#endif // MLIR_EXAMPLES_TOY_LOWERING_H_ diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td new file mode 100644 index 000000000000..8252b3e3a21c --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -0,0 +1,272 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOY_OPS +#else +#define TOY_OPS + +#ifdef MLIR_CALLINTERFACES +#else +include "mlir/Analysis/CallInterfaces.td" +#endif // MLIR_CALLINTERFACES + +#ifdef SHAPE_INFERENCE_INTERFACE +#else +include "toy/ShapeInferenceInterface.td" +#endif // SHAPE_INFERENCE_INTERFACE + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +// We define a toy operation by inherting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'NoSideEffect' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. This can be used to + // auto-generate documenatation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &result, " + "DenseElementsAttr value", [{ + build(builder, result, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<"Builder *builder, OperationState &result, double value", [{ + buildConstantOp(builder, result, value); + }]> + ]; + + // Invoke a static verify method to verify this constant operation. + let verifier = [{ return ::verify(*this); }]; +} + +def AddOp : Toy_Op<"add", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{ + buildAddOp(b, result, lhs, rhs); + }] + >]; +} + +def CastOp : Toy_Op<"cast", + [DeclareOpInterfaceMethods, NoSideEffect, + SameOperandsAndResultShape]> { + let summary = "shape cast operation"; + let description = [{ + The "cast" operation converts a tensor from one type to an equivalent type + without changing any data elements. The source and destination types + must both be tensor types with the same element type. If both are ranked + then the rank should be the same and static dimensions should match. The + operation is invalid if converting to a mismatching constant dimension. + }]; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor:$output); + + // Set the folder bit so that we can fold redundant cast operations. + let hasFolder = 1; +} + +def GenericCallOp : Toy_Op<"generic_call", + [DeclareOpInterfaceMethods]> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the generic call operation. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &result, " + "StringRef callee, ArrayRef arguments", [{ + buildGenericCallOp(builder, result, callee, arguments); + }]> + ]; +} + +def MulOp : Toy_Op<"mul", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{ + buildMulOp(b, result, lhs, rhs); + }] + >]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); +} + +def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // Allow building a ReturnOp with no return operand. + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. + let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &result, Value *input", [{ + buildTransposeOp(b, result, input); + }] + >]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch5/include/toy/Passes.h b/mlir/examples/toy/Ch5/include/toy/Passes.h index 93cf0d5ba155..b6a79eda1767 100644 --- a/mlir/examples/toy/Ch5/include/toy/Passes.h +++ b/mlir/examples/toy/Ch5/include/toy/Passes.h @@ -26,10 +26,16 @@ namespace mlir { class Pass; -} // namespace mlir namespace toy { -std::unique_ptr createShapeInferencePass(); -} // namespace toy +std::unique_ptr createDeadFunctionEliminationPass(); +std::unique_ptr createShapeInferencePass(); + +/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr createLowerToAffinePass(); + +} // end namespace toy +} // end namespace mlir #endif // MLIR_TUTORIAL_TOY_PASSES_H diff --git a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h new file mode 100644 index 000000000000..fc36b5b100dd --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h @@ -0,0 +1,37 @@ +//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ +#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { +namespace toy { + +/// Include the auto-generated declarations. +#include "toy/ShapeInferenceOpInterfaces.h.inc" + +} // end namespace toy +} // end namespace mlir + +#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td new file mode 100644 index 000000000000..4b1240d28d57 --- /dev/null +++ b/mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td @@ -0,0 +1,38 @@ +//===- ShapeInferenceInterface.td - Operation Interface for Shape Inference ----------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifdef SHAPE_INFERENCE_INTERFACE +#else +#define SHAPE_INFERENCE_INTERFACE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "void", "inferShapes"> + ]; +} + +#endif // SHAPE_INFERENCE_INTERFACE diff --git a/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp b/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp new file mode 100644 index 000000000000..b58adb5d52fd --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/DeadFunctionEliminationPass.cpp @@ -0,0 +1,68 @@ +//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a Module level pass performing dead function +// elimination. This is required as a post-processing step after function +// inlining. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/Verifier.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "toy/Passes.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace { +/// This is a simple function DCE pass that deletes all non-main functions after +/// inlining. +/// TODO(riverriddle) This is only necessary because MLIR currently does not +/// have generic DCE support for functions. +class DeadFunctionEliminationPass + : public mlir::ModulePass { +public: + void runOnModule() override { + mlir::ModuleOp module = getModule(); + mlir::SymbolTable moduleSymTable(module); + + // Eliminate non-main functions. + auto mainFn = moduleSymTable.lookup("main"); + for (mlir::FuncOp func : + llvm::make_early_inc_range(module.getOps())) { + if (func != mainFn) + func.erase(); + } + } +}; +} // end anonymous namespace + +/// Create a pass that eliminates inlined functions in toy. +std::unique_ptr mlir::toy::createDeadFunctionEliminationPass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp new file mode 100644 index 000000000000..2ac86cf99846 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -0,0 +1,256 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyInlinerInterface +//===----------------------------------------------------------------------===// + +/// This class defines the interface for handling inlining with Toy +/// operations. +struct ToyInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks + //===--------------------------------------------------------------------===// + + /// All operations within toy can be inlined. + bool isLegalToInline(Operation *, Region *, + BlockAndValueMapping &) const final { + return true; + } + + //===--------------------------------------------------------------------===// + // Transformation Hooks + //===--------------------------------------------------------------------===// + + /// Handle the given inlined terminator(toy.return) by replacing it with a new + /// operation as necessary. + void handleTerminator(Operation *op, + ArrayRef valuesToRepl) const final { + // Only "toy.return" needs to be handled here. + auto returnOp = cast(op); + + // Replace the values directly with the return operands. + assert(returnOp.getNumOperands() == valuesToRepl.size()); + for (const auto &it : llvm::enumerate(returnOp.getOperands())) + valuesToRepl[it.index()]->replaceAllUsesWith(it.value()); + } + + /// Attempts to materialize a conversion for a type mismatch between a call + /// from this dialect, and a callable region. This method should generate an + /// operation that takes 'input' as the only operand, and produces a single + /// result of 'resultType'. If a conversion can not be generated, nullptr + /// should be returned. + Operation *materializeCallConversion(OpBuilder &builder, Value *input, + Type resultType, + Location conversionLoc) const final { + return builder.create(conversionLoc, resultType, input); + } +}; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = builder->getTensorType({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Infer the output shape of the CastOp, this is required by the shape +/// inference interface. +void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); } + +/// Verifier for the constant operation. This corresponds to the `::verify(...)` +/// in the op definition. +static mlir::LogicalResult verify(ConstantOp op) { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. + auto resultType = op.getResult()->getType().cast(); + if (!resultType) + return success(); + + // Check that the rank of the attribute type matches the rank of the constant + // result type. + auto attrType = op.value().getType().cast(); + if (attrType.getRank() != resultType.getRank()) { + return op.emitOpError( + "return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + + // Check that each of the dimensions match between the two types. + for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op.emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value *lhs, mlir::Value *rhs) { + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the AddOp, this is required by the shape inference +/// interface. +void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); } + +static void buildGenericCallOp(mlir::Builder *builder, + mlir::OperationState &state, StringRef callee, + ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +/// Return the callee of the generic call operation, this is required by the +/// call interface. +CallInterfaceCallable GenericCallOp::getCallableForCallee() { + return getAttrOfType("callee"); +} + +/// Get the argument operands to the called function, this is required by the +/// call interface. +Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); } + +static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value *lhs, mlir::Value *rhs) { + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +/// Infer the output shape of the MulOp, this is required by the shape inference +/// interface. +void MulOp::inferShapes() { + auto lhs = getOperand(0)->getType().cast(); + auto rhs = getOperand(1)->getType().cast(); + auto lhsRank = lhs.getShape().size(); + auto rhsRank = rhs.getShape().size(); + if (lhsRank != rhsRank) + return; + + SmallVector dims; + if (lhsRank == 1) { + // dot product, result shape is <1> + dims.push_back(1); + } else if (lhsRank == 2) { + dims.push_back(lhs.getShape()[0]); + dims.push_back(rhs.getShape()[1]); + } else { + return; + } + getResult()->setType(RankedTensorType::get(dims, lhs.getElementType())); +} + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. + const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +static void buildTransposeOp(mlir::Builder *builder, + mlir::OperationState &state, mlir::Value *value) { + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands(value); +} + +void TransposeOp::inferShapes() { + SmallVector dims; + auto arrayTy = getOperand()->getType().cast(); + dims.insert(dims.end(), arrayTy.getShape().begin(), arrayTy.getShape().end()); + if (dims.size() == 2) + std::swap(dims[0], dims[1]); + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp deleted file mode 100644 index 13832f0dae0f..000000000000 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ /dev/null @@ -1,148 +0,0 @@ -//=======- EarlyLowering.cpp - Toy Lowering to Linear Algebra Dialect -=======// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file implements early lowering of Toy IR to Linalg Dialect: we only -// lower the computationally intensive part of the program (matmul...) to a -// dialect specialized for optimizations. -// -// This is intended to showcase how multiple dialects can cohabit in the same -// function. After this lowering, you would still have toy.print in the IR for -// example. -// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "linalg1/Dialect.h" -#include "linalg1/Intrinsics.h" -#include "linalg1/ViewOp.h" -#include "linalg3/TensorOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Helpers.h" -#include "mlir/EDSC/Intrinsics.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" - -#include - -using namespace mlir; - -namespace { -/// Utility function for type casting: this is making the type checker happy, -/// while delaying the actual work involved to convert the type. Most of the -/// time both side of the cast (producer and consumer) will be lowered to a -/// dialect like LLVM and end up with the same LLVM representation, at which -/// point this becomes a no-op and is eliminated. -Value *typeCast(ConversionPatternRewriter &builder, Value *val, Type destTy) { - if (val->getType() == destTy) - return val; - return builder.create(val->getLoc(), val, destTy) - .getResult(); -} - -/// Create a type cast to turn a toy.array into a memref. The Toy Array will be -/// lowered to a memref during buffer allocation, at which point the type cast -/// becomes useless. -Value *memRefTypeCast(ConversionPatternRewriter &builder, Value *val) { - if (val->getType().isa()) - return val; - auto toyArrayTy = val->getType().dyn_cast(); - if (!toyArrayTy) - return val; - return typeCast(builder, val, toyArrayTy.toMemref()); -} - -/// Lower toy.mul to Linalg `matmul`. -/// -/// This class inherit from `ConversionPattern` and override `rewrite`, -/// similarly to the PatternRewriter introduced in the previous chapter. -/// It will be called by the DialectConversion framework (see `LateLowering` -/// class below). -class MulOpConversion : public ConversionPattern { -public: - explicit MulOpConversion(MLIRContext *context) - : ConversionPattern(toy::MulOp::getOperationName(), 1, context) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - using namespace edsc; - using intrinsics::constant_index; - using linalg::intrinsics::range; - using linalg::intrinsics::view; - toy::MulOp mul = cast(op); - auto loc = mul.getLoc(); - Value *result = memRefTypeCast( - rewriter, rewriter.create(loc, mul.getResult()->getType()) - .getResult()); - Value *lhs = memRefTypeCast(rewriter, operands[0]); - auto memrefLHSTy = lhs->getType().cast(); - Value *rhs = memRefTypeCast(rewriter, operands[1]); - auto memrefRHSTy = rhs->getType().cast(); - mlir::edsc::ScopedContext scope(rewriter, loc); - edsc::ValueHandle r0 = - range(constant_index(0), constant_index(memrefLHSTy.getDimSize(0)), - constant_index(1)); - edsc::ValueHandle r1 = - range(constant_index(0), constant_index(memrefLHSTy.getDimSize(1)), - constant_index(1)); - edsc::ValueHandle r2 = - range(constant_index(0), constant_index(memrefRHSTy.getDimSize(1)), - constant_index(1)); - auto lhsView = view(lhs, {r0, r1}); - auto rhsView = view(rhs, {r1, r2}); - auto resultView = view(result, {r0, r2}); - rewriter.create(loc, lhsView, rhsView, resultView); - rewriter.replaceOp(op, {typeCast(rewriter, result, mul.getType())}); - return matchSuccess(); - } -}; - -/// This is lowering to Linalg the parts that are computationally intensive -/// (like matmul for example...) while keeping the rest of the code in the Toy -/// dialect. -struct EarlyLoweringPass : public FunctionPass { - void runOnFunction() override { - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addLegalOp(); - - OwningRewritePatternList patterns; - patterns.insert(&getContext()); - if (failed(applyPartialConversion(getFunction(), target, patterns))) { - emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n"); - signalPassFailure(); - } - } -}; -} // end anonymous namespace - -namespace toy { -std::unique_ptr createEarlyLoweringPass() { - return std::make_unique(); -} -} // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp deleted file mode 100644 index cc45922f1812..000000000000 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ /dev/null @@ -1,470 +0,0 @@ -//====- LateLowering.cpp - Lowering from Toy+Linalg to LLVM -===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file implements late lowering of IR mixing Toy and Linalg to LLVM. -// It involves intemerdiate steps: -// - -// - a mix of affine and standard dialect. -// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "linalg1/Dialect.h" -#include "linalg1/Intrinsics.h" -#include "linalg1/ViewOp.h" -#include "linalg3/ConvertToLLVMDialect.h" -#include "linalg3/TensorOps.h" -#include "linalg3/Transforms.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Helpers.h" -#include "mlir/EDSC/Intrinsics.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Parser.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" - -#include - -using namespace mlir; - -namespace { -/// Utility function for type casting: this is making the type checker happy, -/// while delaying the actual work involved to convert the type. Most of the -/// time both side of the cast (producer and consumer) will be lowered to a -/// dialect like LLVM and end up with the same LLVM representation, at which -/// point this becomes a no-op and is eliminated. -Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) { - if (val->getType() == destTy) - return val; - return builder.create(val->getLoc(), val, destTy) - .getResult(); -} - -/// Create a type cast to turn a toy.array into a memref. The Toy Array will be -/// lowered to a memref during buffer allocation, at which point the type cast -/// becomes useless. -Value *memRefTypeCast(PatternRewriter &builder, Value *val) { - if (val->getType().isa()) - return val; - auto toyArrayTy = val->getType().dyn_cast(); - if (!toyArrayTy) - return val; - return typeCast(builder, val, toyArrayTy.toMemref()); -} - -/// Lower a toy.add to an affine loop nest. -/// -/// This class inherit from `ConversionPattern` and override `rewrite`, -/// similarly to the PatternRewriter introduced in the previous chapter. -/// It will be called by the DialectConversion framework (see `LateLowering` -/// class below). -class AddOpConversion : public ConversionPattern { -public: - explicit AddOpConversion(MLIRContext *context) - : ConversionPattern(toy::AddOp::getOperationName(), 1, context) {} - - /// Lower the `op` by generating IR using the `rewriter` builder. The builder - /// is setup with a new function, the `operands` array has been populated with - /// the rewritten operands for `op` in the new function. - /// The results created by the new IR with the builder are returned, and their - /// number must match the number of result of `op`. - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto add = cast(op); - auto loc = add.getLoc(); - // Create a `toy.alloc` operation to allocate the output buffer for this op. - Value *result = memRefTypeCast( - rewriter, rewriter.create(loc, add.getResult()->getType()) - .getResult()); - Value *lhs = memRefTypeCast(rewriter, operands[0]); - Value *rhs = memRefTypeCast(rewriter, operands[1]); - - using namespace edsc; - ScopedContext scope(rewriter, loc); - ValueHandle zero = intrinsics::constant_index(0); - MemRefView vRes(result), vLHS(lhs), vRHS(rhs); - IndexedValue iRes(result), iLHS(lhs), iRHS(rhs); - IndexHandle i, j, M(vRes.ub(0)); - if (vRes.rank() == 1) { - LoopNestBuilder({&i}, {zero}, {M}, - {1})([&] { iRes(i) = iLHS(i) + iRHS(i); }); - } else { - assert(vRes.rank() == 2 && "only rank 1 and 2 are supported right now"); - IndexHandle N(vRes.ub(1)); - LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, - {1, 1})([&] { iRes(i, j) = iLHS(i, j) + iRHS(i, j); }); - } - - // Return the newly allocated buffer, with a type.cast to preserve the - // consumers. - rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())}); - return matchSuccess(); - } -}; - -/// Lowers `toy.print` to a loop nest calling `printf` on every individual -/// elements of the array. -class PrintOpConversion : public ConversionPattern { -public: - explicit PrintOpConversion(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // Get or create the declaration of the printf function in the module. - LLVM::LLVMFuncOp printfFunc = getPrintf(op->getParentOfType()); - - auto print = cast(op); - auto loc = print.getLoc(); - // We will operate on a MemRef abstraction, we use a type.cast to get one - // if our operand is still a Toy array. - Value *operand = memRefTypeCast(rewriter, operands[0]); - Type retTy = printfFunc.getType().getFunctionResultType(); - - // Create our loop nest now - using namespace edsc; - using extractvalue = intrinsics::ValueBuilder; - using llvmCall = intrinsics::ValueBuilder; - ScopedContext scope(rewriter, loc); - ValueHandle zero = intrinsics::constant_index(0); - ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f ")); - MemRefView vOp(operand); - IndexedValue iOp(operand); - IndexHandle i, j, M(vOp.ub(0)); - - auto *dialect = op->getContext()->getRegisteredDialect(); - auto i8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo(); - - ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n")); - if (vOp.rank() == 1) { - // clang-format off - LoopBuilder(&i, zero, M, 1)([&]{ - llvmCall(retTy, - rewriter.getSymbolRefAttr(printfFunc), - {extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)), - iOp(i)}); - }); - llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), - {extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))}); - // clang-format on - } else { - IndexHandle N(vOp.ub(1)); - // clang-format off - LoopBuilder(&i, zero, M, 1)([&]{ - LoopBuilder(&j, zero, N, 1)([&]{ - llvmCall( - retTy, - rewriter.getSymbolRefAttr(printfFunc), - {extractvalue(i8PtrTy, fmtCst, rewriter.getIndexArrayAttr(0)), - iOp(i, j)}); - }); - llvmCall( - retTy, - rewriter.getSymbolRefAttr(printfFunc), - {extractvalue(i8PtrTy, fmtEol, rewriter.getIndexArrayAttr(0))}); - }); - // clang-format on - } - rewriter.replaceOp(op, llvm::None); - return matchSuccess(); - } - -private: - // Turn a string into a toy.alloc (malloc/free abstraction) and a sequence - // of stores into the buffer, and return a MemRef into the buffer. - Value *getConstantCharBuffer(PatternRewriter &builder, Location loc, - StringRef data) const { - auto retTy = - builder.getMemRefType(data.size() + 1, builder.getIntegerType(8)); - Value *result = builder.create(loc, retTy).getResult(); - using namespace edsc; - using intrinsics::constant_index; - using intrinsics::constant_int; - ScopedContext scope(builder, loc); - MemRefView vOp(result); - IndexedValue iOp(result); - for (uint64_t i = 0; i < data.size(); ++i) { - iOp(constant_index(i)) = constant_int(data[i], 8); - } - iOp(constant_index(data.size())) = constant_int(0, 8); - return result; - } - - /// Return the prototype declaration for printf in the module, create it if - /// necessary. - LLVM::LLVMFuncOp getPrintf(ModuleOp module) const { - auto printfFunc = module.lookupSymbol("printf"); - if (printfFunc) - return printfFunc; - - // Create a function declaration for printf, signature is `i32 (i8*, ...)` - OpBuilder builder(module.getBodyRegion()); - auto *dialect = - module.getContext()->getRegisteredDialect(); - - auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(dialect); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8Ty(dialect).getPointerTo(); - auto printfTy = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); - return builder.create(builder.getUnknownLoc(), "printf", - printfTy, - ArrayRef()); - } -}; - -/// Lowers constant to a sequence of store in a buffer. -class ConstantOpConversion : public ConversionPattern { -public: - explicit ConstantOpConversion(MLIRContext *context) - : ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - toy::ConstantOp cstOp = cast(op); - auto loc = cstOp.getLoc(); - auto retTy = cstOp.getResult()->getType().cast(); - auto shape = retTy.getShape(); - Value *result = memRefTypeCast( - rewriter, rewriter.create(loc, retTy).getResult()); - - auto cstValue = cstOp.getValue(); - auto f64Ty = rewriter.getF64Type(); - using namespace edsc; - using intrinsics::constant_float; - using intrinsics::constant_index; - ScopedContext scope(rewriter, loc); - MemRefView vOp(result); - IndexedValue iOp(result); - for (uint64_t i = 0, ie = shape[0]; i < ie; ++i) { - if (shape.size() == 1) { - auto value = cstValue.getValue(ArrayRef{i}); - iOp(constant_index(i)) = constant_float(value, f64Ty); - continue; - } - for (uint64_t j = 0, je = shape[1]; j < je; ++j) { - auto value = cstValue.getValue(ArrayRef{i, j}); - iOp(constant_index(i), constant_index(j)) = - constant_float(value, f64Ty); - } - } - rewriter.replaceOp(op, result); - return matchSuccess(); - } -}; - -/// Lower transpose operation to an affine loop nest. -class TransposeOpConversion : public ConversionPattern { -public: - explicit TransposeOpConversion(MLIRContext *context) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto transpose = cast(op); - auto loc = transpose.getLoc(); - Value *result = memRefTypeCast( - rewriter, - rewriter.create(loc, transpose.getResult()->getType()) - .getResult()); - Value *operand = memRefTypeCast(rewriter, operands[0]); - - using namespace edsc; - ScopedContext scope(rewriter, loc); - ValueHandle zero = intrinsics::constant_index(0); - MemRefView vRes(result), vOperand(operand); - IndexedValue iRes(result), iOperand(operand); - IndexHandle i, j, M(vRes.ub(0)), N(vRes.ub(1)); - // clang-format off - LoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{ - iRes(i, j) = iOperand(j, i); - }); - // clang-format on - - rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())}); - return matchSuccess(); - } -}; - -// Lower toy.return to standard return operation. -class ReturnOpConversion : public ConversionPattern { -public: - explicit ReturnOpConversion(MLIRContext *context) - : ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - // Argument is optional, handle both cases. - if (op->getNumOperands()) - rewriter.replaceOpWithNewOp(op, operands[0]); - else - rewriter.replaceOpWithNewOp(op); - return matchSuccess(); - } -}; - -/// This is the main class registering our individual converter classes with -/// the DialectConversion framework in MLIR. -class ToyTypeConverter : public TypeConverter { -protected: - /// Convert a Toy type, this gets called for block and region arguments, and - /// attributes. - Type convertType(Type t) override { - if (auto array = t.dyn_cast()) - return array.toMemref(); - return t; - } - - /// Materialize a conversion to allow for partial lowering of types. - Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, - Location loc) override { - assert(inputs.size() == 1 && "expected only one input value"); - return rewriter.create(loc, inputs[0], resultType); - } -}; - -/// This is lowering to Linalg the parts that can be (matmul and add on arrays) -/// and is targeting LLVM otherwise. -struct LateLoweringPass : public ModulePass { - void runOnModule() override { - ToyTypeConverter typeConverter; - OwningRewritePatternList toyPatterns; - toyPatterns.insert( - &getContext()); - mlir::populateFuncOpTypeConversionPattern(toyPatterns, &getContext(), - typeConverter); - - // Perform Toy specific lowering. - ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addDynamicallyLegalOp([&](FuncOp op) { - return typeConverter.isSignatureLegal(op.getType()); - }); - target.addLegalOp(); - if (failed(applyPartialConversion(getModule(), target, toyPatterns, - &typeConverter))) { - emitError(UnknownLoc::get(getModule().getContext()), - "error lowering Toy\n"); - signalPassFailure(); - } - - // At this point the IR is almost using only standard and affine dialects. - // A few things remain before we emit LLVM IR. First to reuse as much of - // MLIR as possible we will try to lower everything to the standard and/or - // affine dialect: they already include conversion to the LLVM dialect. - - // First patch calls type to return memref instead of ToyArray - for (auto function : getModule().getOps()) { - function.walk([&](Operation *op) { - auto callOp = dyn_cast(op); - if (!callOp) - return; - if (!callOp.getNumResults()) - return; - auto retToyTy = - callOp.getResult(0)->getType().dyn_cast(); - if (!retToyTy) - return; - callOp.getResult(0)->setType(retToyTy.toMemref()); - }); - } - - for (auto function : getModule().getOps()) { - function.walk([&](Operation *op) { - // Turns toy.alloc into sequence of alloc/dealloc (later malloc/free). - if (auto allocOp = dyn_cast(op)) { - auto result = allocTensor(allocOp); - allocOp.replaceAllUsesWith(result); - allocOp.erase(); - return; - } - // Eliminate all type.cast before lowering to LLVM. - if (auto typeCastOp = dyn_cast(op)) { - typeCastOp.replaceAllUsesWith(typeCastOp.getOperand()); - typeCastOp.erase(); - return; - } - }); - } - - // Lower Linalg to affine - for (auto function : getModule().getOps()) - linalg::lowerToLoops(function); - - getModule().dump(); - - // Finally convert to LLVM Dialect - linalg::convertLinalg3ToLLVM(getModule()); - } - - /// Allocate buffers (malloc/free) for Toy operations. This can't be done as - /// part of dialect conversion framework since we need to insert `dealloc` - /// operations just before the return, but the conversion framework is - /// operating in a brand new function: we don't have the return to hook the - /// dealloc operations. - Value *allocTensor(toy::AllocOp alloc) { - OpBuilder builder(alloc); - auto retTy = alloc.getResult()->getType(); - - auto memRefTy = retTy.dyn_cast(); - if (!memRefTy) - memRefTy = retTy.cast().toMemref(); - if (!memRefTy) { - alloc.emitOpError("is expected to allocate a Toy array or a MemRef"); - llvm_unreachable("fatal error"); - } - auto loc = alloc.getLoc(); - Value *result = builder.create(loc, memRefTy).getResult(); - - // Insert a `dealloc` operation right before the `return` operations, unless - // it is returned itself in which case the caller is responsible for it. - alloc.getParentRegion()->walk([&](Operation *op) { - auto returnOp = dyn_cast(op); - if (!returnOp) - return; - if (returnOp.getNumOperands() && returnOp.getOperand(0) == alloc) - return; - builder.setInsertionPoint(returnOp); - builder.create(alloc.getLoc(), result); - }); - return result; - } -}; -} // end anonymous namespace - -namespace toy { -std::unique_ptr createLateLoweringPass() { - return std::make_unique(); -} -} // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp new file mode 100644 index 000000000000..a8e38aef7ad3 --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -0,0 +1,318 @@ +//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a partial lowering of Toy operations to a combination of +// affine loops and standard operations. This lowering expects that all calls +// have been inlined, and all shapes have been resolved. +// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" +#include "toy/Passes.h" + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Convert the given TensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(TensorType type) { + assert(type.hasRank() && "expected only ranked shapes"); + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value *insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter) { + auto alloc = rewriter.create(loc, type); + + // Make sure to allocate at the beginning of the block. + auto *parentBlock = alloc.getOperation()->getBlock(); + alloc.getOperation()->moveBefore(&parentBlock->front()); + + // Make sure to deallocate this alloc at the end of the block. This is fine + // as toy functions have no control flow. + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + return alloc; +} + +/// This defines the function type used to process an iteration of a lowered +/// loop. It takes as input a rewriter, an array of memRefOperands corresponding +/// to the operands of the input operation, and the set of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = function_ref memRefOperands, + ArrayRef loopIvs)>; + +static void lowerOpToLoops(Operation *op, ArrayRef operands, + PatternRewriter &rewriter, + LoopIterationFn processIteration) { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // Create an empty affine loop for each of the dimensions within the shape. + SmallVector loopIvs; + for (auto dim : tensorType.getShape()) { + auto loop = rewriter.create(loc, /*lb=*/0, dim, /*step=*/1); + loop.getBody()->clear(); + loopIvs.push_back(loop.getInductionVar()); + + // Terminate the loop body and update the rewriter insertion point to the + // beginning of the loop. + rewriter.setInsertionPointToStart(loop.getBody()); + rewriter.create(loc); + rewriter.setInsertionPointToStart(loop.getBody()); + } + + // Generate a call to the processing function with the rewriter, the memref + // operands, and the loop induction variables. This function will return the + // value to store at the current index. + Value *valueToStore = processIteration(rewriter, operands, loopIvs); + rewriter.create(loc, valueToStore, alloc, + llvm::makeArrayRef(loopIvs)); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); +} + +namespace { +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Binary operations +//===----------------------------------------------------------------------===// + +template +struct BinaryOpLowering : public ConversionPattern { + BinaryOpLowering(MLIRContext *ctx) + : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the BinaryOp. This + // allows for using the nice named accessors that are generated by the + // ODS. + typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + + // Generate loads for the element of 'lhs' and 'rhs' at the inner + // loop. + auto loadedLhs = + rewriter.create(loc, binaryAdaptor.lhs(), loopIvs); + auto loadedRhs = + rewriter.create(loc, binaryAdaptor.rhs(), loopIvs); + + // Create the binary operation performed on the loaded values. + return rewriter.create(loc, loadedLhs, loadedRhs); + }); + return matchSuccess(); + } +}; +using AddOpLowering = BinaryOpLowering; +using MulOpLowering = BinaryOpLowering; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Constant operations +//===----------------------------------------------------------------------===// + +struct ConstantOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { + DenseElementsAttr constantValue = op.value(); + Location loc = op.getLoc(); + + // When lowering the constant operation, we allocate and assign the constant + // values to a corresponding memref allocation. + auto tensorType = op.getType().cast(); + auto memRefType = convertTensorToMemRef(tensorType); + auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); + + // We will be generating constant indices up-to the largest dimension. + // Create these constants up-front to avoid large amounts of redundant + // operations. + auto valueShape = memRefType.getShape(); + SmallVector constantIndices; + for (auto i : llvm::seq( + 0, *std::max_element(valueShape.begin(), valueShape.end()))) + constantIndices.push_back(rewriter.create(loc, i)); + + // The constant operation represents a multi-dimensional constant, so we + // will need to generate a store for each of the elements. The following + // functor recursively walks the dimensions of the constant shape, + // generating a store when the recursion hits the base case. + SmallVector indices; + auto valueIt = constantValue.getValues().begin(); + std::function storeElements = [&](uint64_t dimension) { + // The last dimension is the base case of the recursion, at this point + // we store the element at the given index. + if (dimension == valueShape.size()) { + rewriter.create( + loc, rewriter.create(loc, *valueIt++), alloc, + llvm::makeArrayRef(indices)); + return; + } + + // Otherwise, iterate over the current dimension and add the indices to + // the list. + for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { + indices.push_back(constantIndices[i]); + storeElements(dimension + 1); + indices.pop_back(); + } + }; + + // Start the element storing recursion from the first dimension. + storeElements(/*dimension=*/0); + + // Replace this operation with the generated alloc. + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Return operations +//===----------------------------------------------------------------------===// + +struct ReturnOpLowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { + // During this lowering, we expect that all function calls have been + // inlined. + if (op.hasOperand()) + return matchFailure(); + + // We lower "toy.return" directly to "std.return". + rewriter.replaceOpWithNewOp(op); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Transpose operations +//===----------------------------------------------------------------------===// + +struct TransposeOpLowering : public ConversionPattern { + TransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + lowerOpToLoops( + op, operands, rewriter, + [loc](PatternRewriter &rewriter, ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the TransposeOp. + // This allows for using the nice named accessors that are generated + // by the ODS. + toy::TransposeOpOperandAdaptor tranposeAdaptor(memRefOperands); + Value *input = tranposeAdaptor.input(); + + // Transpose the elements by generating a load from the reverse + // indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return rewriter.create(loc, input, reverseIvs); + }); + return matchSuccess(); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// ToyToAffineLoweringPass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to affine loops of the toy operations that are +/// computationally intensive (like matmul for example...) while keeping the +/// rest of the code in the Toy dialect. +namespace { +struct ToyToAffineLoweringPass : public FunctionPass { + void runOnFunction() final; +}; +} // end anonymous namespace. + +void ToyToAffineLoweringPass::runOnFunction() { + auto function = getFunction(); + + // We only lower the main function as we expect that all other functions have + // been inlined. + if (function.getName() != "main") + return; + + // Verify that the given main has no inputs and results. + if (function.getNumArguments() || function.getType().getNumResults()) { + function.emitError("expected 'main' to have 0 inputs and 0 results"); + return signalPassFailure(); + } + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine` and `Standard` dialects. + target.addLegalDialect(); + + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. + target.addIllegalDialect(); + target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); +} + +/// Create a pass for lowering operations in the `Affine` and `Std` dialects, +/// for a subset of the Toy IR (e.g. matmul). +std::unique_ptr mlir::toy::createLowerToAffinePass() { + return std::make_unique(); +} diff --git a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp index 8f80da9ee972..5f12d0a8798a 100644 --- a/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch5/mlir/MLIRGen.cpp @@ -25,30 +25,30 @@ #include "toy/Dialect.h" #include "mlir/Analysis/Verifier.h" -#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Types.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/raw_ostream.h" #include +using namespace mlir::toy; using namespace toy; + +using llvm::ArrayRef; using llvm::cast; using llvm::dyn_cast; using llvm::isa; +using llvm::makeArrayRef; using llvm::ScopedHashTableScope; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; -using std::make_unique; namespace { @@ -57,56 +57,43 @@ namespace { /// This will emit operations that are specific to the Toy language, preserving /// the semantics of the language and (hopefully) allow to perform accurate /// analysis and transformation based on these high level semantics. -/// -/// At this point we take advantage of the "raw" MLIR APIs to create operations -/// that haven't been registered in any way with MLIR. These operations are -/// unknown to MLIR, custom passes could operate by string-matching the name of -/// these operations, but no other type checking or semantic is associated with -/// them natively by MLIR. class MLIRGenImpl { public: - MLIRGenImpl(mlir::MLIRContext &context) : context(context) {} + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR - /// Module. - mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) { + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { // We create an empty MLIR module and codegen functions one at a time and // add them to the module. - theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); for (FunctionAST &F : moduleAST) { auto func = mlirGen(F); if (!func) return nullptr; - theModule->push_back(func); + theModule.push_back(func); } - // FIXME: (in the next chapter...) without registering a dialect in MLIR, - // this won't do much, but it should at least check some structural - // properties. - if (failed(mlir::verify(*theModule))) { - emitError(mlir::UnknownLoc::get(&context), "module verification error"); + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); return nullptr; } - return std::move(theModule); + return theModule; } private: - /// In MLIR (like in LLVM) a "context" object holds the memory allocation and - /// the ownership of many internal structure of the IR and provide a level - /// of "uniquing" across multiple modules (types for instance). - mlir::MLIRContext &context; + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; - /// A "module" matches a source file: it contains a list of functions. - mlir::OwningModuleRef theModule; - - /// The builder is a helper class to create IR inside a function. It is - /// re-initialized every time we enter a function and kept around as a - /// convenience for emitting individual operations. - /// The builder is stateful, in particular it keeeps an "insertion point": - /// this is where the next operations will be introduced. - std::unique_ptr builder; + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; /// The symbol table maps a variable name to a value in the current scope. /// Entering a function creates a new scope, and the function arguments are @@ -116,37 +103,35 @@ private: /// Helper conversion for a Toy AST location to an MLIR location. mlir::Location loc(Location loc) { - return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context), - loc.line, loc.col, &context); + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); } - /// Declare a variable in the current scope, return true if the variable + /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - bool declare(llvm::StringRef var, mlir::Value *value) { - if (symbolTable.count(var)) { - return false; - } + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) { + if (symbolTable.count(var)) + return mlir::failure(); symbolTable.insert(var, value); - return true; + return mlir::success(); } /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + // This is a generic function, the return type will be inferred later. - llvm::SmallVector ret_types; - // Arguments type is uniformly a generic array. + // Arguments type are uniformly unranked tensors. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); - auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto func_type = builder.getFunctionType(arg_types, llvm::None); + auto function = mlir::FuncOp::create(location, proto.getName(), func_type); // Mark the function as generic: it'll require type specialization for every // call site. if (function.getNumArguments()) - function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); - + function.setAttr("toy.generic", builder.getUnitAttr()); return function; } @@ -165,29 +150,39 @@ private: // argument list as the function itself. auto &entryBlock = *function.addEntryBlock(); auto &protoArgs = funcAST.getProto()->getArgs(); + // Declare all the function arguments in the symbol table. for (const auto &name_value : llvm::zip(protoArgs, entryBlock.getArguments())) { - declare(std::get<0>(name_value)->getName(), std::get<1>(name_value)); + if (failed(declare(std::get<0>(name_value)->getName(), + std::get<1>(name_value)))) + return nullptr; } - // Create a builder for the function, it will be used throughout the codegen - // to create operations in this function. - builder = std::make_unique(function.getBody()); + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) { + if (mlir::failed(mlirGen(*funcAST.getBody()))) { function.erase(); return nullptr; } - // Implicitly return void if no return statement was emited. + // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function.getBlocks().back().back().getName().getStringRef() != - "toy.return") { - ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); - mlirGen(fakeRet); + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + getType(VarType{}))); } return function; @@ -206,11 +201,11 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. // - mlir::Value *L = mlirGen(*binop.getLHS()); - if (!L) + mlir::Value *lhs = mlirGen(*binop.getLHS()); + if (!lhs) return nullptr; - mlir::Value *R = mlirGen(*binop.getRHS()); - if (!R) + mlir::Value *rhs = mlirGen(*binop.getRHS()); + if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -218,123 +213,113 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder->create(location, L, R).getResult(); - break; + return builder.create(location, lhs, rhs); case '*': - return builder->create(location, L, R).getResult(); - default: - emitError(loc(binop.loc()), "error: invalid binary operator '") - << binop.getOp() << "'"; - return nullptr; + return builder.create(location, lhs, rhs); } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; } - // This is a reference to a variable in an expression. The variable is - // expected to have been declared and so should have a value in the symbol - // table, otherwise emit an error and return nullptr. + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. mlir::Value *mlirGen(VariableExprAST &expr) { - if (symbolTable.count(expr.getName())) - return symbolTable.lookup(expr.getName()); + if (auto *variable = symbolTable.lookup(expr.getName())) + return variable; + emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; return nullptr; } - // Emit a return operation, return true on success. - bool mlirGen(ReturnExprAST &ret) { + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); - // `return` takes an optional expression, we need to account for it here. - if (!ret.getExpr().hasValue()) { - builder->create(location); - return true; + + // 'return' takes an optional expression, handle that case here. + mlir::Value *expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); } - auto *expr = mlirGen(*ret.getExpr().getValue()); - if (!expr) - return false; - builder->create(location, expr); - return true; + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); } - // Emit a literal/constant array. It will be emitted as a flattened array of - // data in an Attribute attached to a `toy.constant` operation. - // See documentation on [Attributes](LangRef.md#attributes) for more details. - // Here is an excerpt: - // - // Attributes are the mechanism for specifying constant data in MLIR in - // places where a variable is never allowed [...]. They consist of a name - // and a [concrete attribute value](#attribute-values). It is possible to - // attach attributes to operations, functions, and function arguments. The - // set of expected attributes, their structure, and their interpretation - // are all contextually dependent on what they are attached to. - // - // Example, the source level statement: - // var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; - // will be converted to: - // %0 = "toy.constant"() {value: dense, - // [[1.000000e+00, 2.000000e+00, 3.000000e+00], - // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64> - // + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// mlir::Value *mlirGen(LiteralExprAST &lit) { - auto location = loc(lit.loc()); - // The attribute is a vector with an attribute per element (number) in the - // array, see `collectData()` below for more details. - std::vector data; + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, std::multiplies())); collectData(lit, data); - // FIXME: using a tensor type is a HACK here. - // Can we do differently without registering a dialect? Using a string blob? - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto dataType = builder->getTensorType(lit.getDims(), elementType); + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = builder.getTensorType(lit.getDims(), elementType); - // This is the actual attribute that actually hold the list of values for - // this array literal. - auto dataAttribute = builder->getDenseElementsAttr(dataType, data) - .cast(); + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); - // Build the MLIR op `toy.constant`, only boilerplate below. - return builder->create(location, lit.getDims(), dataAttribute) - .getResult(); + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); } - // Recursive helper function to accumulate the data that compose an array - // literal. It flattens the nested structure in the supplied vector. For - // example with this array: - // [[1, 2], [3, 4]] - // we will generate: - // [ 1, 2, 3, 4 ] - // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`. - // Attributes are the way MLIR attaches constant to operations and functions. - void collectData(ExprAST &expr, std::vector &data) { + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { for (auto &value : lit->getValues()) collectData(*value, data); return; } + assert(isa(expr) && "expected literal or number expr"); - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto attr = mlir::FloatAttr::getChecked( - elementType, cast(expr).getValue(), loc(expr.loc())); - data.push_back(attr); + data.push_back(cast(expr).getValue()); } - // Emit a call expression. It emits specific operations for the `transpose` - // builtin. Other identifiers are assumed to be user-defined functions. + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. mlir::Value *mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); - std::string callee = call.getCallee(); - if (callee == "transpose") { - if (call.getArgs().size() != 1) { - emitError(location, "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); - return nullptr; - } - mlir::Value *arg = mlirGen(*call.getArgs()[0]); - return builder->create(location, arg).getResult(); - } - // Codegen the operands first + // Codegen the operands first. SmallVector operands; for (auto &expr : call.getArgs()) { auto *arg = mlirGen(*expr); @@ -342,34 +327,41 @@ private: return nullptr; operands.push_back(arg); } - // Calls to user-defined function are mapped to a custom call that takes - // the callee name as an attribute. - return builder->create(location, call.getCallee(), operands) - .getResult(); + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + return builder.create(location, callee, operands); } - // Emit a call expression. It emits specific operations for two builtins: - // transpose(x) and print(x). Other identifiers are assumed to be user-defined - // functions. Return false on failure. - bool mlirGen(PrintExprAST &call) { + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). + mlir::LogicalResult mlirGen(PrintExprAST &call) { auto *arg = mlirGen(*call.getArg()); if (!arg) - return false; - auto location = loc(call.loc()); - builder->create(location, arg); - return true; + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); } - // Emit a constant for a single number (FIXME: semantic? broadcast?) + /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value *mlirGen(NumberExprAST &num) { - auto location = loc(num.loc()); - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(), - loc(num.loc())); - return builder->create(location, attr).getResult(); + return builder.create(loc(num.loc()), num.getValue()); } - // Dispatch codegen for the right expression subclass using RTTI. + /// Dispatch codegen for the right expression subclass using RTTI. mlir::Value *mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: @@ -383,84 +375,82 @@ private: case toy::ExprAST::Expr_Num: return mlirGen(cast(expr)); default: - emitError(loc(expr.loc()), - "MLIR codegen encountered an unhandled expr kind '") + emitError(loc(expr.loc())) + << "MLIR codegen encountered an unhandled expr kind '" << Twine(expr.getKind()) << "'"; return nullptr; } } - // Handle a variable declaration, we'll codegen the expression that forms the - // initializer and record the value in the symbol table before returning it. - // Future expressions will be able to reference this variable through symbol - // table lookup. + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. mlir::Value *mlirGen(VarDeclExprAST &vardecl) { - mlir::Value *value = nullptr; - auto location = loc(vardecl.loc()); - if (auto init = vardecl.getInitVal()) { - value = mlirGen(*init); - if (!value) - return nullptr; - // We have the initializer value, but in case the variable was declared - // with specific shape, we emit a "reshape" operation. It will get - // optimized out later as needed. - if (!vardecl.getType().shape.empty()) { - value = builder - ->create( - location, value, - getType(vardecl.getType()).cast()) - .getResult(); - } - } else { + auto init = vardecl.getInitVal(); + if (!init) { emitError(loc(vardecl.loc()), "missing initializer in variable declaration"); return nullptr; } - // Register the value in the symbol table - declare(vardecl.getName(), value); + + mlir::Value *value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; return value; } - /// Codegen a list of expression, return false if one of them hit an error. - bool mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. if (auto *vardecl = dyn_cast(expr.get())) { if (!mlirGen(*vardecl)) - return false; + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) { - if (!mlirGen(*ret)) - return false; - return true; - } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (!mlirGen(*print)) - return false; + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } + // Generic expression dispatch codegen. if (!mlirGen(*expr)) - return false; + return mlir::failure(); } - return true; + return mlir::success(); } - /// Build a type from a list of shape dimensions. Types are `array` followed - /// by an optional dimension list, example: array<2, 2> - /// They are wrapped in a `toy` dialect (see next chapter) and get printed: - /// !toy.array<2, 2> - template mlir::Type getType(T shape) { - SmallVector shape64(shape.begin(), shape.end()); - return ToyArrayType::get(&context, shape64); + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return builder.getTensorType(builder.getF64Type()); + + // Otherwise, we use the given shape. + return builder.getTensorType(shape, builder.getF64Type()); } - /// Build an MLIR type from a Toy AST variable type - /// (forward to the generic getType(T) above). + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). mlir::Type getType(const VarType &type) { return getType(type.shape); } }; diff --git a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp index 527afa71dae4..1f572015c39e 100644 --- a/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp @@ -1,4 +1,4 @@ -//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===// +//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// // // Copyright 2019 The MLIR Authors. // @@ -15,213 +15,55 @@ // limitations under the License. // ============================================================================= // -// This file implements a Module level pass performing interprocedural +// This file implements a Function level pass performing interprocedural // propagation of array shapes through function specialization. // //===----------------------------------------------------------------------===// -#include "toy/Dialect.h" - -#include "mlir/Analysis/Verifier.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/StandardTypes.h" #include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSet.h" +#include "toy/Dialect.h" +#include "toy/Passes.h" +#include "toy/ShapeInferenceInterface.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include -#define DEBUG_TYPE "toy-shape-inference" +#define DEBUG_TYPE "shape-inference" +using namespace mlir; using namespace toy; -using llvm::MutableArrayRef; -using llvm::SmallVector; -using llvm::SmallVectorImpl; -using llvm::StringRef; -using llvm::Twine; -/// Create mangled name for function specialization. We will simply append the -/// shape of the arguments to the function name. For example calling -/// -/// "toy.generic_call"(%1, %3) {callee: "foo"} -/// : (!toy<"array<2, 3>">, !toy<"array<2, 3>">) -> !toy<"array"> -/// -/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could -/// have provide a function with a similar name. But we will claim this as a -/// feature: this allow the user to provide custom specialization! -static std::string mangle(StringRef funcName, - MutableArrayRef operands) { - std::string mangledName; - mangledName.reserve(funcName.size() + operands.size() * 6); - mangledName = funcName; - for (auto &operand : operands) { - auto arrayTy = operand.get()->getType().cast(); - mangledName += "_"; - const char *sep = ""; - for (auto dim : arrayTy.getShape()) { - mangledName += (sep + Twine(dim)).str(); - sep = "x"; - } - } - return mangledName; -} +/// Include the auto-generated definitions for the shape inference interfaces. +#include "toy/ShapeInferenceOpInterfaces.cpp.inc" namespace { - -/// The ShapeInferencePass is a ModulePass: it will run on the Module as a -/// whole. MLIR also supports FunctionPass which are restricted to modify a -/// single function at a time. This pass couldn't be a function pass due the -/// nature of its interprocedural transformations. +/// The ShapeInferencePass is a FunctionPass that performs intra-procedural +/// shape inference. /// -/// The algorithm has two levels, first intra-procedurally: +/// Algorithm: /// -/// 1) Build a worklist containing all the operations that are returning -/// a generic Toy array: these are the operations that need shape +/// 1) Build a worklist containing all the operations that return a +/// dynamically shaped tensor: these are the operations that need shape /// inference. /// 2) Iterate on the worklist: /// a) find an operation to process: the next ready operation in the /// worklist has all of its arguments non-generic, /// b) if no operation is found, break out of the loop, /// c) remove the operation from the worklist, -/// d) infer the shape of its output from the arguments type. -/// 3) If the worklist is empty, the algorithm succeeded and we infer the -/// return type for the function from the return operation. +/// d) infer the shape of its output from the argument types. +/// 3) If the worklist is empty, the algorithm succeeded. /// -/// There is a twist though: when a call to a generic function is encountered, -/// shape inference requires the return type of the callee to be inferred first. -/// At this point we need to run specialize the callee by cloning it. Here is -/// the inter-procedural flow: -/// -/// 1) Keep a worklist of function to process. Start with function "main". -/// 2) While the worklist isn't empty: -/// a) Take the last inserted function in the worklist. -/// b) Run the intra-procedural shape inference on this function. -/// c) If the intra-procedural shape inference can't complete, it returns -/// a FuncOp that needs to be inferred first. In this case, queue this -/// new function and continue. Otherwise the inference succeeded and we -/// can pop from the queue. -/// -class ShapeInferencePass : public mlir::ModulePass { +class ShapeInferencePass : public mlir::FunctionPass { public: - // One entry in the inter-procedural worklist. It keeps track of the - // function to process, the mangled name for this specialization, and the - // types of the arguments on which to specialize. - struct FunctionToSpecialize { - mlir::FuncOp function; - std::string mangledName; - SmallVector argumentsType; - }; - - void runOnModule() override { - auto module = getModule(); - mlir::ModuleManager moduleManager(module); - auto main = moduleManager.lookupSymbol("main"); - if (!main) { - emitError(mlir::UnknownLoc::get(module.getContext()), - "shape inference failed: can't find a main function\n"); - signalPassFailure(); - return; - } - - /// Inter-procedural loop, initialize with `main` and iterate till - /// successfully infer the full reachable call-graph from main. - SmallVector worklist; - worklist.push_back({main, "", {}}); - while (!worklist.empty()) { - if (failed(specialize(worklist, moduleManager))) - return; - } - - // Delete any generic function left - // FIXME: we may want this as a separate pass. - for (mlir::FuncOp function : - llvm::make_early_inc_range(module.getOps())) { - if (auto genericAttr = - function.getAttrOfType("toy.generic")) { - if (genericAttr.getValue()) - function.erase(); - } - } - } - - /// Run inference on a function. If a mangledName is provided, we need to - /// specialize the function: to this end clone it first. - mlir::LogicalResult - specialize(SmallVectorImpl &funcWorklist, - mlir::ModuleManager &moduleManager) { - FunctionToSpecialize &functionToSpecialize = funcWorklist.back(); - mlir::FuncOp f = functionToSpecialize.function; - - // Check if cloning for specialization is needed (usually anything but main) - // We will create a new function with the concrete types for the parameters - // and clone the body into it. - if (!functionToSpecialize.mangledName.empty()) { - if (moduleManager.lookupSymbol( - functionToSpecialize.mangledName)) { - funcWorklist.pop_back(); - // FuncOp already specialized, move on. - return mlir::success(); - } - // Create a new function with a generic array return type, it will be - // updated when the inference for the function body completes. - auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType, - {ToyArrayType::get(&getContext())}, - &getContext()); - auto newFunction = - mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName, - type, f.getDialectAttrs()); - moduleManager.insert(newFunction); - - // Clone the function body - mlir::BlockAndValueMapping mapper; - f.cloneInto(newFunction, mapper); - LLVM_DEBUG({ - llvm::dbgs() << "====== Cloned : \n"; - f.dump(); - llvm::dbgs() << "====== Into : \n"; - newFunction.dump(); - }); - f = newFunction; - f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); - // Remap the entry-block arguments - // FIXME: this seems like a bug in `cloneInto()` above? - auto &entryBlock = f.getBlocks().front(); - int blockArgSize = entryBlock.getArguments().size(); - assert(blockArgSize == static_cast(f.getType().getInputs().size())); - entryBlock.addArguments(f.getType().getInputs()); - auto argList = entryBlock.getArguments(); - for (int argNum = 0; argNum < blockArgSize; ++argNum) { - argList[0]->replaceAllUsesWith(argList[blockArgSize]); - entryBlock.eraseArgument(0); - } - assert(succeeded(verify(f))); - } - LLVM_DEBUG(llvm::dbgs() - << "Run shape inference on : '" << f.getName() << "'\n"); - - auto *toyDialect = getContext().getRegisteredDialect("toy"); - if (!toyDialect) { - signalPassFailure(); - return emitError(mlir::UnknownLoc::get(&getContext()), - "Toy dialect is not registered"); - } + void runOnFunction() override { + auto f = getFunction(); // Populate the worklist with the operations that need shape inference: - // these are the Toy operations that return a generic array. + // these are operations that return a dynamic shape. llvm::SmallPtrSet opWorklist; f.walk([&](mlir::Operation *op) { - if (op->getDialect() == toyDialect) { - if (op->getNumResults() == 1 && - op->getResult(0)->getType().cast().isGeneric()) - opWorklist.insert(op); - } + if (returnsDynamicShape(op)) + opWorklist.insert(op); }); // Iterate on the operations in the worklist until all operations have been @@ -229,152 +71,43 @@ public: while (!opWorklist.empty()) { // Find the next operation ready for inference, that is an operation // with all operands already resolved (non-generic). - auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) { - return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) { - return !ty.cast().isGeneric(); - }); - }); + auto nextop = llvm::find_if(opWorklist, returnsDynamicShape); if (nextop == opWorklist.end()) - break; // failure: no operations can be inferred. + break; - mlir::Operation *op = *nextop; + Operation *op = *nextop; opWorklist.erase(op); + + // Ask the operation to infer its output shapes. LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); - - // The add operation is trivial: propagate the input type as is. - if (auto addOp = llvm::dyn_cast(op)) { - op->getResult(0)->setType(op->getOperand(0)->getType()); - continue; - } - - // Transpose is easy: just invert the dimensions. - if (op->getName().getStringRef() == "toy.transpose") { - SmallVector dims; - auto arrayTy = op->getOperand(0)->getType().cast(); - dims.insert(dims.end(), arrayTy.getShape().begin(), - arrayTy.getShape().end()); - if (dims.size() == 2) - std::swap(dims[0], dims[1]); - op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims)); - continue; - } - - // Multiplication is a bit trickier, handle rank 1 as dot product and rank - // 2 as matrix multiplications. - // We need to be careful about rank mismatch here: the verifier could - // catch it but shape inference earlier in the pass could generate an - // invalid IR (from an invalid Toy input of course) and we wouldn't want - // to crash here. - if (auto mulOp = llvm::dyn_cast(op)) { - auto lhs = mulOp.getLHS()->getType().cast(); - auto rhs = mulOp.getRHS()->getType().cast(); - auto lhsRank = lhs.getShape().size(); - auto rhsRank = rhs.getShape().size(); - if (lhsRank != rhsRank) { - return op->emitError("shape mismatch: LHS and RHS must have the same " - "rank for multiplication, got ") - << lhsRank << " vs " << lhsRank; - } - SmallVector dims; - if (lhsRank == 1) { - // dot product, result shape is <1> - dims.push_back(1); - } else { - if (lhsRank != 2) { - return op->emitError("shape mismatch: expect rank 1 or 2 for mul " - "operands, got ") - << lhsRank; - } - dims.push_back(lhs.getShape()[0]); - dims.push_back(rhs.getShape()[1]); - } - op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims)); - continue; - } - - // Process calls: lookup the callee after mangling the name with the - // argument shapes. If the callee does not exist, we stop the inference - // for this function, queue the callee in the inter-procedural work list, - // and return. The current function stays in the work list and will - // restart after the callee is processed. - if (auto callOp = llvm::dyn_cast(op)) { - auto calleeName = callOp.getCalleeName(); - auto callee = moduleManager.lookupSymbol(calleeName); - if (!callee) { - signalPassFailure(); - return f.emitError("shape inference failed, call to unknown '") - << calleeName << "'"; - } - auto mangledName = mangle(calleeName, op->getOpOperands()); - LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName - << "', mangled: '" << mangledName << "'\n"); - auto mangledCallee = - moduleManager.lookupSymbol(mangledName); - if (!mangledCallee) { - // Can't find the target, this is where we queue the request for the - // callee and stop the inference for the current function now. - funcWorklist.push_back({callee, std::move(mangledName), - llvm::to_vector<4>(op->getOperandTypes())}); - return mlir::success(); - } - // Found a specialized callee! Let's turn this into a normal call - // operation. - SmallVector operands(op->getOperands()); - mlir::OpBuilder builder(op); - auto newCall = - builder.create(op->getLoc(), mangledCallee, operands); - if (newCall.getNumResults()) { - op->getResult(0)->replaceAllUsesWith(newCall.getResult(0)); - op->erase(); - continue; - } + if (auto shapeOp = dyn_cast(op)) { + shapeOp.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); } } - // Done with inference on this function, removing it from the worklist. - funcWorklist.pop_back(); - // Mark the function as non-generic now that inference has succeeded - f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext())); - // If the operation worklist isn't empty, this indicates a failure. if (!opWorklist.empty()) { + f.emitError("Shape inference failed, ") + << opWorklist.size() << " operations couldn't be inferred\n"; signalPassFailure(); - auto diag = f.emitError("shape inference failed, ") - << opWorklist.size() << " operations couldn't be inferred\n"; - for (auto *ope : opWorklist) - diag << " - " << *ope << "\n"; - return diag; } + } - // Finally, update the return type of the function based on the argument to - // the return operation. - for (auto &block : f.getBlocks()) { - auto ret = llvm::cast(block.getTerminator()); - if (!ret) - continue; - if (ret.getNumOperands() && - f.getType().getResult(0) == ret.getOperand()->getType()) - // type match, we're done - break; - SmallVector retTy; - if (ret.getNumOperands()) - retTy.push_back(ret.getOperand()->getType()); - std::vector argumentsType; - for (auto arg : f.getArguments()) - argumentsType.push_back(arg->getType()); - auto newType = - mlir::FunctionType::get(argumentsType, retTy, &getContext()); - f.setType(newType); - assert(succeeded(verify(f))); - break; - } - return mlir::success(); + /// A utility method that returns if the given operation has a dynamically + /// shaped result. + static bool returnsDynamicShape(Operation *op) { + return llvm::any_of(op->getResultTypes(), [](Type resultType) { + return !resultType.isa(); + }); } }; } // end anonymous namespace -namespace toy { -std::unique_ptr createShapeInferencePass() { +/// Create a Shape Inference pass. +std::unique_ptr mlir::toy::createShapeInferencePass() { return std::make_unique(); } -} // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 4798ad188d15..47e1abc6c744 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -15,24 +15,30 @@ // limitations under the License. // ============================================================================= // -// This file implements a simple combiner for optimizing pattern in the Toy -// dialect. +// This file implements a set of simple combiners for optimizing operations in +// the Toy dialect. // //===----------------------------------------------------------------------===// -#include "toy/Dialect.h" - #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" - +#include "toy/Dialect.h" #include - -namespace toy { +using namespace mlir; +using namespace toy; namespace { +/// Include the patterns defined in the Declarative Rewrite framework. +#include "ToyCombine.inc" +} // end anonymous namespace -/// Fold transpose(transpose(x) -> transpose(x) +/// Fold simple cast operations that return the same type as the input. +OpFoldResult CastOp::fold(ArrayRef operands) { + return mlir::impl::foldCastOp(*this); +} + +/// This is an example of a c++ rewrite pattern for the TransposeOp. It +/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x) struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// We register this pattern to match every toy.transpose in the IR. /// The "benefit" is used by the framework to order the patterns and process @@ -40,9 +46,9 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { SimplifyRedundantTranspose(mlir::MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1) {} - /// This method is attempting to match a pattern and rewrite it. The rewriter - /// argument is the orchestrator of the sequence of rewrites. It is expected - /// to interact with it to perform any changes to the IR from here. + /// This method attempts to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. The pattern is + /// expected to interact with it to perform any changes to the IR from here. mlir::PatternMatchResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { @@ -55,132 +61,23 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { if (!transposeInputOp) return matchFailure(); - // Use the rewriter to perform the replacement + // Use the rewriter to perform the replacement. rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); return matchSuccess(); } }; -/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place. -struct SimplifyReshapeConstant : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::PatternMatchResult - matchAndRewrite(ReshapeOp reshape, - mlir::PatternRewriter &rewriter) const override { - // Look through the input of the current reshape. - ConstantOp constantOp = llvm::dyn_cast_or_null( - reshape.getOperand()->getDefiningOp()); - - // If the input is defined by another constant, bingo! - if (!constantOp) - return matchFailure(); - - auto reshapeType = reshape.getType().cast(); - if (auto valueAttr = - constantOp.getAttrOfType("value")) { - // FIXME Check matching of element count! - // auto oldType = constantOp.getType(); - auto newType = rewriter.getTensorType( - reshapeType.getShape(), valueAttr.getType().getElementType()); - auto newAttr = valueAttr.reshape(newType); - rewriter.replaceOpWithNewOp(reshape, reshapeType.getShape(), - newAttr); - } else if (auto valueAttr = - constantOp.getAttrOfType("value")) { - // Broadcast - auto dataSize = std::accumulate(reshapeType.getShape().begin(), - reshapeType.getShape().end(), 1, - std::multiplies()); - std::vector data(dataSize, valueAttr); - auto tensorTy = rewriter.getTensorType(reshapeType.getShape(), - reshapeType.getElementType()); - auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data); - rewriter.replaceOpWithNewOp(reshape, reshapeType.getShape(), - newAttr); - } else { - llvm_unreachable("Unsupported Constant format"); - } - return matchSuccess(); - } -}; - -/// Fold reshape(reshape(x)) -> reshape(x) -struct SimplifyReshapeReshape : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::PatternMatchResult - matchAndRewrite(ReshapeOp op, - mlir::PatternRewriter &rewriter) const override { - // Look through the input of the current reshape. - mlir::Value *reshapeInput = op.getOperand(); - - // If the input is defined by another reshape, bingo! - if (!matchPattern(reshapeInput, mlir::m_Op())) - return matchFailure(); - - // Use the rewriter to perform the replacement - rewriter.replaceOp(op, {reshapeInput}); - return matchSuccess(); - } -}; - -/// Fold reshape(x)) -> x, when input type matches output type -struct SimplifyNullReshape : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::PatternMatchResult - matchAndRewrite(ReshapeOp op, - mlir::PatternRewriter &rewriter) const override { - if (op.getOperand()->getType() != op.getType()) - return matchFailure(); - rewriter.replaceOp(op, {op.getOperand()}); - return matchSuccess(); - } -}; - -} // end anonymous namespace. - -// Register our patterns for rewrite by the Canonicalization framework. -void TransposeOp::getCanonicalizationPatterns( - mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { +/// Register our patterns as "canonicalization" patterns on the TransposeOp so +/// that they can be picked up by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { results.insert(context); } -// Register our patterns for rewrite by the Canonicalization framework. -void ReshapeOp::getCanonicalizationPatterns( - mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.insert(context); +/// Register our patterns as "canonicalization" patterns on the ReshapeOp so +/// that they can be picked up by the Canonicalization framework. +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); } - -namespace { - -/// Fold type.cast(x) -> x, when input type matches output type -struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::PatternMatchResult - matchAndRewrite(TypeCastOp typeCast, - mlir::PatternRewriter &rewriter) const override { - auto resTy = typeCast.getType(); - auto *candidateOp = typeCast.getOperation(); - while (llvm::isa_and_nonnull(candidateOp)) { - if (resTy == candidateOp->getOperand(0)->getType()) { - rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)}); - return matchSuccess(); - } - candidateOp = candidateOp->getOperand(0)->getDefiningOp(); - } - return matchFailure(); - } -}; - -} // end anonymous namespace. - -void TypeCastOp::getCanonicalizationPatterns( - mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { - results.insert(context); -} - -} // namespace toy diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.td b/mlir/examples/toy/Ch5/mlir/ToyCombine.td new file mode 100644 index 000000000000..0a63861fa96f --- /dev/null +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.td @@ -0,0 +1,73 @@ +//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines language-specific pattern match optimizations for Toy using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef TOY_COMBINE +#define TOY_COMBINE + +#ifndef OP_BASE +include "toy/Ops.td" +#endif // OP_BASE + +/// Note: The DRR definition used for defining patterns is shown below: +/// +/// class Pattern< +/// dag sourcePattern, list resultPatterns, +/// list additionalConstraints = [], +/// dag benefitsAdded = (addBenefit 0) +/// >; + +//===----------------------------------------------------------------------===// +// Basic Pattern-Match and Rewrite +//===----------------------------------------------------------------------===// + +// Reshape(Reshape(x)) = Reshape(x) +def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), + (ReshapeOp $arg)>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite using Native Code Call +//===----------------------------------------------------------------------===// + +// Native Code Calls may be used for more complex transformations using inline +// C++ and C++ helper functions. + +// Reshape(Constant(x)) = x' +def ReshapeConstant : + NativeCodeCall<"$0.reshape(($1->getType()).cast())">; +def FoldConstantReshapeOptPattern : Pat< + (ReshapeOp:$res (ConstantOp $arg)), + (ConstantOp (ReshapeConstant $arg, $res))>; + +//===----------------------------------------------------------------------===// +// Pattern-Match and Rewrite with Constraints +//===----------------------------------------------------------------------===// + +// DRR allows for constraint checking when the transformation is conditional +// on operand properties. + +// Reshape(x) = x, where input and output shapes are identical +def TypesAreIdentical : ConstraintgetType() == $1->getType()">>; +def RedundantReshapeOptPattern : Pat< + (ReshapeOp:$res $arg), (replaceWithValue $arg), + [(TypesAreIdentical $res, $arg)]>; + +#endif // TOY_COMBINE diff --git a/mlir/examples/toy/Ch5/mlir/ToyDialect.cpp b/mlir/examples/toy/Ch5/mlir/ToyDialect.cpp deleted file mode 100644 index 68a48ee01790..000000000000 --- a/mlir/examples/toy/Ch5/mlir/ToyDialect.cpp +++ /dev/null @@ -1,403 +0,0 @@ -//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file implements the dialect for the Toy IR: custom type parsing and -// operation verification. -// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/ADT/iterator_range.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/Regex.h" -#include "llvm/Support/raw_ostream.h" - -using llvm::ArrayRef; -using llvm::raw_ostream; -using llvm::raw_string_ostream; -using llvm::SmallVector; -using llvm::StringRef; -using llvm::Twine; - -namespace toy { -namespace detail { - -/// This class holds the implementation of the ToyArrayType. -/// It is intended to be uniqued based on its content and owned by the context. -struct ToyArrayTypeStorage : public mlir::TypeStorage { - /// This defines how we unique this type in the context: our key contains - /// only the shape, a more complex type would have multiple entries in the - /// tuple here. - /// The element of the tuples usually matches 1-1 the arguments from the - /// public `get()` method arguments from the facade. - using KeyTy = std::tuple>; - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(std::get<0>(key)); - } - /// When the key hash hits an existing type, we compare the shape themselves - /// to confirm we have the right type. - bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); } - - /// This is a factory method to create our type storage. It is only - /// invoked after looking up the type in the context using the key and not - /// finding it. - static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - const KeyTy &key) { - // Copy the shape array into the bumpptr allocator owned by the context. - ArrayRef shape = allocator.copyInto(std::get<0>(key)); - - // Allocate the instance for the ToyArrayTypeStorage itself - auto *storage = allocator.allocate(); - // Initialize the instance using placement new. - return new (storage) ToyArrayTypeStorage(shape); - } - - ArrayRef getShape() const { return shape; } - -private: - ArrayRef shape; - - /// Constructor is only invoked from the `construct()` method above. - ToyArrayTypeStorage(ArrayRef shape) : shape(shape) {} -}; - -} // namespace detail - -mlir::Type ToyArrayType::getElementType() { - return mlir::FloatType::getF64(getContext()); -} - -ToyArrayType ToyArrayType::get(mlir::MLIRContext *context, - ArrayRef shape) { - return Base::get(context, ToyTypeKind::TOY_ARRAY, shape); -} - -ArrayRef ToyArrayType::getShape() { return getImpl()->getShape(); } - -mlir::MemRefType ToyArrayType::toMemref() { - auto memRefType = mlir::MemRefType::get(getShape(), getElementType(), {}, 0); - return memRefType; -} - -/// Dialect creation, the instance will be owned by the context. This is the -/// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { - addOperations(); - addTypes(); -} - -/// Parse a type registered to this dialect, we expect only Toy arrays. -mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const { - // Sanity check: we only support array or array<...> - if (!tyData.startswith("array")) { - emitError(loc, "invalid Toy type '" + tyData + "', array expected"); - return nullptr; - } - // Drop the "array" prefix from the type name, we expect either an empty - // string or just the shape. - tyData = tyData.drop_front(StringRef("array").size()); - // This is the generic array case without shape, early return it. - if (tyData.empty()) - return ToyArrayType::get(getContext()); - - // Use a regex to parse the shape (for efficient we should store this regex in - // the dialect itself). - SmallVector matches; - auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$"); - if (!shapeRegex.match(tyData, &matches)) { - emitError(loc, "invalid toy array shape '" + tyData + "'"); - return nullptr; - } - SmallVector shape; - // Iterate through the captures, skip the first one which is the full string. - for (auto dimStr : - llvm::make_range(std::next(matches.begin()), matches.end())) { - if (dimStr.startswith(",")) - continue; // POSIX misses non-capturing groups. - if (dimStr.empty()) - continue; // '*' makes it an optional group capture - // Convert the capture to an integer - unsigned long long dim; - if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) { - emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr); - return mlir::Type(); - } - shape.push_back(dim); - } - // Finally we collected all the dimensions in the shape, - // create the array type. - return ToyArrayType::get(getContext(), shape); -} - -/// Print a Toy array type, for example `array<2, 3, 4>` -void ToyDialect::printType(mlir::Type type, raw_ostream &os) const { - auto arrayTy = type.dyn_cast(); - if (!arrayTy) { - os << "unknown toy type"; - return; - } - os << "array"; - if (!arrayTy.getShape().empty()) { - os << "<"; - mlir::interleaveComma(arrayTy.getShape(), os); - os << ">"; - } -} - -//////////////////////////////////////////////////////////////////////////////// -//////////////////// Custom Operations for the Dialect ///////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Helper to verify that the result of an operation is a Toy array type. -template static mlir::LogicalResult verifyToyReturnArray(T *op) { - if (!op->getResult()->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its argument, got " - << op->getResult()->getType(); - return op->emitOpError(os.str()); - } - return mlir::success(); -} - -/// Helper to verify that the two operands of a binary operation are Toy -/// arrays.. -template static mlir::LogicalResult verifyToyBinOperands(T *op) { - if (!op->getOperand(0)->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its LHS, got " - << op->getOperand(0)->getType(); - return op->emitOpError(os.str()); - } - if (!op->getOperand(1)->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its LHS, got " - << op->getOperand(0)->getType(); - return op->emitOpError(os.str()); - } - return mlir::success(); -} - -/// Build a constant operation. -/// The builder is passed as an argument, so is the state that this method is -/// expected to fill in order to build the operation. -void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, - ArrayRef shape, mlir::DenseElementsAttr value) { - state.types.push_back(ToyArrayType::get(builder->getContext(), shape)); - auto dataAttribute = builder->getNamedAttr("value", value); - state.attributes.push_back(dataAttribute); -} - -/// Build a constant operation. -/// The builder is passed as an argument, so is the state that this method is -/// expected to fill in order to build the operation. -void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::FloatAttr value) { - // Broadcast and forward to the other build factory - mlir::Type elementType = mlir::FloatType::getF64(builder->getContext()); - auto dataType = builder->getTensorType({1}, elementType); - auto dataAttribute = builder->getDenseElementsAttr(dataType, {value}) - .cast(); - - ConstantOp::build(builder, state, {1}, dataAttribute); -} - -/// Verifier for constant operation. -mlir::LogicalResult ConstantOp::verify() { - // Ensure that the return type is a Toy array - if (failed(verifyToyReturnArray(this))) - return mlir::failure(); - - // We expect the constant itself to be stored as an attribute. - auto dataAttr = getAttr("value").dyn_cast(); - if (!dataAttr) { - return emitOpError( - "missing valid `value` DenseElementsAttribute on toy.constant()"); - } - auto attrType = dataAttr.getType().dyn_cast(); - if (!attrType) { - return emitOpError( - "missing valid `value` DenseElementsAttribute on toy.constant()"); - } - - // If the return type of the constant is not a generic array, the shape must - // match the shape of the attribute holding the data. - auto resultType = getResult()->getType().cast(); - if (!resultType.isGeneric()) { - if (attrType.getRank() != resultType.getRank()) { - return emitOpError("The rank of the toy.constant return type must match " - "the one of the attached value attribute: " + - Twine(attrType.getRank()) + - " != " + Twine(resultType.getRank())); - } - for (int dim = 0; dim < attrType.getRank(); ++dim) { - if (attrType.getShape()[dim] != resultType.getShape()[dim]) { - std::string msg; - raw_string_ostream os(msg); - return emitOpError( - "Shape mismatch between toy.constant return type and its " - "attribute at dimension " + - Twine(dim) + ": " + Twine(attrType.getShape()[dim]) + - " != " + Twine(resultType.getShape()[dim])); - } - } - } - return mlir::success(); -} - -void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { - // Generic call always returns a generic ToyArray initially - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.assign(arguments.begin(), arguments.end()); - auto calleeAttr = builder->getStringAttr(callee); - state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr)); -} - -mlir::LogicalResult GenericCallOp::verify() { - // Verify that every operand is a Toy Array - for (int opId = 0, num = getNumOperands(); opId < num; ++opId) { - if (!getOperand(opId)->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its " << opId << " operand, got " - << getOperand(opId)->getType(); - return emitOpError(os.str()); - } - } - return mlir::success(); -} - -/// Return the name of the callee. -StringRef GenericCallOp::getCalleeName() { - return getAttr("callee").cast().getValue(); -} - -template static mlir::LogicalResult verifyToySingleOperand(T *op) { - if (!op->getOperand()->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its argument, got " - << op->getOperand()->getType(); - return op->emitOpError(os.str()); - } - return mlir::success(); -} - -void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { - // Return does not return any value and has an optional single argument - if (value) - state.operands.push_back(value); -} - -mlir::LogicalResult ReturnOp::verify() { - if (getNumOperands() > 1) - return emitOpError("expects zero or one operand, got " + - Twine(getNumOperands())); - if (hasOperand() && failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { - // Print does not return any value and has a single argument - state.operands.push_back(value); -} - -mlir::LogicalResult PrintOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.push_back(value); -} - -mlir::LogicalResult TransposeOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value, ToyArrayType reshapedType) { - state.types.push_back(reshapedType); - state.operands.push_back(value); -} - -mlir::LogicalResult ReshapeOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - auto retTy = getResult()->getType().dyn_cast(); - if (!retTy) - return emitOpError("toy.reshape is expected to produce a Toy array"); - if (retTy.isGeneric()) - return emitOpError("toy.reshape is expected to produce a shaped Toy array, " - "got a generic one."); - return mlir::success(); -} - -void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.push_back(lhs); - state.operands.push_back(rhs); -} - -mlir::LogicalResult AddOp::verify() { - if (failed(verifyToyBinOperands(this))) - return mlir::failure(); - return mlir::success(); -} - -void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.push_back(lhs); - state.operands.push_back(rhs); -} - -mlir::LogicalResult MulOp::verify() { - if (failed(verifyToyBinOperands(this))) - return mlir::failure(); - return mlir::success(); -} - -void AllocOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Type retType) { - state.types.push_back(retType); -} - -void TypeCastOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value, mlir::Type destTy) { - state.operands.push_back(value); - state.types.push_back(destTy); -} - -} // namespace toy diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp index 6600ff6e5560..2c3875fbcbc5 100644 --- a/mlir/examples/toy/Ch5/toyc.cpp +++ b/mlir/examples/toy/Ch5/toyc.cpp @@ -20,30 +20,23 @@ //===----------------------------------------------------------------------===// #include "toy/Dialect.h" -#include "toy/Lowering.h" #include "toy/MLIRGen.h" #include "toy/Parser.h" #include "toy/Passes.h" -#include "linalg1/Dialect.h" #include "mlir/Analysis/Verifier.h" -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Target/LLVMIR.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/StringRef.h" -#include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" using namespace toy; @@ -64,28 +57,14 @@ static cl::opt inputType( "load the input file as an MLIR file"))); namespace { -enum Action { - None, - DumpAST, - DumpMLIR, - DumpMLIRLinalg, - DumpLLVMDialect, - DumpLLVMIR, - RunJIT -}; +enum Action { None, DumpAST, DumpMLIR, DumpMLIRAffine }; } static cl::opt emitAction( "emit", cl::desc("Select the kind of output desired"), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), - cl::values(clEnumValN(DumpMLIRLinalg, "mlir-linalg", - "output the MLIR dump after linalg lowering")), - cl::values(clEnumValN(DumpLLVMDialect, "llvm-dialect", - "output the LLVM MLIR Dialect dump")), - cl::values(clEnumValN(DumpLLVMIR, "llvm-ir", "output the LLVM IR dump")), - cl::values( - clEnumValN(RunJIT, "jit", - "JIT the code and run it by invoking the main function"))); + cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", + "output the MLIR dump after affine lowering"))); static cl::opt EnableOpt("opt", cl::desc("Enable optimizations")); @@ -103,174 +82,81 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.ParseModule(); } -mlir::LogicalResult optimize(mlir::ModuleOp module) { - mlir::PassManager pm(module.getContext()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(createShapeInferencePass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - - // Apply any generic pass manager command line options. - applyPassManagerCLOptions(pm); - - return pm.run(module); -} - -mlir::LogicalResult lowerDialect(mlir::ModuleOp module, bool OnlyLinalg) { - mlir::PassManager pm(module.getContext()); - pm.addPass(createEarlyLoweringPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - if (!OnlyLinalg) { - pm.addPass(createLateLoweringPass()); - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createCSEPass()); - } - // Apply any generic pass manager command line options. - applyPassManagerCLOptions(pm); - - return pm.run(module); -} - -mlir::OwningModuleRef loadFileAndProcessModule( - mlir::MLIRContext &context, bool EnableLinalgLowering = false, - bool EnableLLVMLowering = false, bool EnableOpt = false) { - - mlir::OwningModuleRef module; - if (inputType == InputType::MLIR || - llvm::StringRef(inputFilename).endswith(".mlir")) { - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); - if (std::error_code EC = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << EC.message() << "\n"; - return nullptr; - } - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module = mlir::parseSourceFile(sourceMgr, &context); - if (!module) { - llvm::errs() << "Error can't load file " << inputFilename << "\n"; - return nullptr; - } - if (failed(mlir::verify(*module))) { - llvm::errs() << "Error verifying MLIR module\n"; - return nullptr; - } - } else { +int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) { + // Handle '.toy' input to the compiler. + if (inputType != InputType::MLIR && + !llvm::StringRef(inputFilename).endswith(".mlir")) { auto moduleAST = parseInputFile(inputFilename); module = mlirGen(context, *moduleAST); + return !module ? 1 : 0; } - if (!module) - return nullptr; - if (EnableOpt) { - if (failed(optimize(*module))) { - llvm::errs() << "Module optimization failed\n"; - return nullptr; - } + + // Otherwise, the input is '.mlir'. + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); + if (std::error_code EC = fileOrErr.getError()) { + llvm::errs() << "Could not open input file: " << EC.message() << "\n"; + return -1; } - if (EnableLLVMLowering || EnableLinalgLowering) { - if (failed(lowerDialect(*module, !EnableLLVMLowering))) { - llvm::errs() << "Module lowering failed\n"; - return nullptr; - } + + // Parse the input mlir. + llvm::SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + module = mlir::parseSourceFile(sourceMgr, &context); + if (!module) { + llvm::errs() << "Error can't load file " << inputFilename << "\n"; + return 3; } - return module; + return 0; } int dumpMLIR() { + // Register our Dialect with MLIR. + mlir::registerDialect(); + mlir::MLIRContext context; - auto module = - loadFileAndProcessModule(context, /*EnableLinalgLowering=*/false, - /*EnableLLVMLowering=*/false, EnableOpt); - if (!module) - return -1; + mlir::OwningModuleRef module; + if (int error = loadMLIR(context, module)) + return error; + + mlir::PassManager pm(&context); + // Apply any generic pass manager command line options and run the pipeline. + applyPassManagerCLOptions(pm); + + // Check to see what granularity of MLIR we are compiling to. + bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; + + if (EnableOpt || isLoweringToAffine) { + // Inline all functions into main and then delete them. + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::toy::createDeadFunctionEliminationPass()); + + // Now that there is only one function, we can infer the shapes of each of + // the operations. + pm.addPass(mlir::toy::createShapeInferencePass()); + pm.addPass(mlir::createCanonicalizerPass()); + } + + if (isLoweringToAffine) { + // Partially lower the toy dialect with a few cleanups afterwards. + pm.addPass(mlir::toy::createLowerToAffinePass()); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createCSEPass()); + + // Add optimizations if enabled. + if (EnableOpt) { + pm.addPass(mlir::createLoopFusionPass()); + pm.addPass(mlir::createMemRefDataFlowOptPass()); + } + } + + if (mlir::failed(pm.run(*module))) + return 4; + module->dump(); return 0; } -int dumpMLIRLinalg() { - mlir::MLIRContext context; - auto module = loadFileAndProcessModule(context, /*EnableLinalgLowering=*/true, - /*EnableLLVMLowering=*/false, - /* EnableOpt=*/true); - if (!module) - return -1; - module->dump(); - return 0; -} - -int dumpLLVMDialect() { - mlir::MLIRContext context; - auto module = loadFileAndProcessModule( - context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true, - /* EnableOpt=*/true); - if (!module) { - llvm::errs() << "Failed to load/lower MLIR module\n"; - return -1; - } - module->dump(); - return 0; -} - -int dumpLLVMIR() { - mlir::MLIRContext context; - auto module = loadFileAndProcessModule( - context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true, - /* EnableOpt=*/true); - if (!module) { - llvm::errs() << "Failed to load/lower MLIR module\n"; - return -1; - } - auto llvmModule = translateModuleToLLVMIR(*module); - if (!llvmModule) { - llvm::errs() << "Failed to emit LLVM IR\n"; - return -1; - } - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); - auto optPipeline = mlir::makeOptimizingTransformer( - /* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0, - /* targetMachine=*/nullptr); - if (auto err = optPipeline(llvmModule.get())) { - llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; - return -1; - } - llvm::errs() << *llvmModule << "\n"; - return 0; -} - -int runJit() { - mlir::MLIRContext context; - auto module = loadFileAndProcessModule( - context, /*EnableLinalgLowering=*/false, /* EnableLLVMLowering=*/true, - /* EnableOpt=*/true); - - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - // Create an MLIR execution engine. The execution engine eagerly JIT-compiles - // the module. - auto optPipeline = mlir::makeOptimizingTransformer( - /* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0, - /* targetMachine=*/nullptr); - auto maybeEngine = mlir::ExecutionEngine::create(*module, optPipeline); - assert(maybeEngine && "failed to construct an execution engine"); - auto &engine = maybeEngine.get(); - - // Invoke the JIT-compiled function with the arguments. Note that, for API - // uniformity reasons, it takes a list of type-erased pointers to arguments. - auto invocationResult = engine->invoke("main"); - if (invocationResult) { - llvm::errs() << "JIT invocation failed\n"; - return -1; - } - - return 0; -} - int dumpAST() { if (inputType == InputType::MLIR) { llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; @@ -286,10 +172,6 @@ int dumpAST() { } int main(int argc, char **argv) { - // Register our Dialects with MLIR - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerPassManagerCLOptions(); cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); @@ -297,18 +179,10 @@ int main(int argc, char **argv) { case Action::DumpAST: return dumpAST(); case Action::DumpMLIR: + case Action::DumpMLIRAffine: return dumpMLIR(); - case Action::DumpMLIRLinalg: - return dumpMLIRLinalg(); - case Action::DumpLLVMDialect: - return dumpLLVMDialect(); - case Action::DumpLLVMIR: - return dumpLLVMIR(); - case Action::RunJIT: - return runJit(); default: llvm::errs() << "No action specified (parsing only?), use -emit=\n"; - return -1; } return 0; diff --git a/mlir/g3doc/Tutorials/Toy/Ch-5.md b/mlir/g3doc/Tutorials/Toy/Ch-5.md index 11373876a5b3..841e13dea0b7 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-5.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-5.md @@ -1,4 +1,4 @@ -# Chapter 5: CodeGen via Lowering to Lower-Level Dialects +# Chapter 5 - Partial Lowering to Lower-Level Dialects for Optimization At this point, we are eager to generate actual code and see our Toy language taking life. We will obviously use LLVM to generate code, but just showing the @@ -6,293 +6,356 @@ LLVM builder interface wouldn't be very exciting here. Instead, we will show how to perform progressive lowering through a mix of dialects coexisting in the same function. -To make it more interesting, we will consider that we want to reuse existing -optimizations implemented in a dialect optimizing linear algebra: `Linalg`. This -dialect is tailored to the computation heavy part of the program, and is -limited: it doesn't support representing our `toy.print` builtin for instance, -neither should it! Instead we can target `Linalg` for the computation heavy part -of Toy (mostly matmul), we will target the `Affine` dialect for other -well-formed loop nest, and directly the `LLVM IR` dialect for lowering `print`. +To make it more interesting, in this chapter we will will consider that we want +to reuse existing optimizations implemented in a dialect optimizing affine +transformations: `Affine`. This dialect is tailored to the computation-heavy +part of the program, and is limited: it doesn't support representing our +`toy.print` builtin for instance, neither should it! Instead we can target +`Affine` for the computation heavy part of Toy, and in the +[next chapter](Ch-6.md) directly the `LLVM IR` dialect for lowering `print`. As +part of this lowering, we will be lowering from the +[TensorType](../../LangRef.md#tensor-type), that `Toy` operates on, to the +[MemRefType](../../LangRef.md#memref-type) that is indexed via an affine +loop-nest. Tensors represent an abstract value-typed sequence of data, meaning +that they don't live in any memory. MemRefs on the other hand represent lower +level buffer access, as they are concrete references to a region of memory. -# The `DialectConversion` Framework +# Dialect Conversions -Similarly to the canonicalization patterns introduced in the previous section, -the `DialectConversion` framework involves its own set of patterns. This -framework operates a bit differently from the canonicalizer: a new function is -created and the pattern matching operation in the original function are expected -to emit the IR in the new function. +MLIR contains many different dialects, so it is important to have a unified +framework for converting between them. This is where the `DialectConversion` +framework comes into play. This framework allows for transforming a set of +`illegal` operations to a set of `legal` ones. To use this framework we need to +provide two things: -Dialect conversion requires three components, implemented by overriding virtual -methods defined in `DialectConversion`: +* A [Conversion Target](../../DialectConversion.md#conversion-target) -- Type Conversion: for things like block arguments' type. -- Function signature conversion: for every function it is invoked with the - function type and the conversion generates a new prototype for the converted - function. The default implementation will call into the type conversion for - the returned values and for each of the parameters. -- Operations conversions: each pattern is expected to generate new results - matching the current operations' in the new function. This may involve - generating one or multiple new operations, or possibly just remapping - existing operands (folding). + - This is the formal specification of what operations, or dialects, are + legal for the conversion. Operations that aren't legal will require + rewrite patterns to perform legalization. -A typical starting point for implementing our lowering would be: +* A set of + [Rewrite Patterns](../../DialectConversion.md#rewrite-pattern-specification) + + - These are the set of [patterns](../../QuickstartRewrites.md) used to + convert `illegal` operations into a set of zero or more `legal` ones. + +* Optionally, A [Type Converter](../../DialectConversion.md#type-conversion). + + - If provided, this is used to convert the types of block arguments. We + won't be needing this for our conversion. + +## Conversion Target + +For our purposes, we want to convert the compute intensive `Toy` operations into +a combination of operations from the `Affine` `Standard` dialects for further +optimization. To start off the lowering, we first define our conversion target: ```c++ -class Lowering : public DialectConversion { -public: - // This gets called for block and region arguments, and attributes. - Type convertType(Type t) override { /*...*/ } +void ToyToAffineLoweringPass::runOnFunction() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + mlir::ConversionTarget target(getContext()); - // This gets called for functions. - FunctionType convertFunctionSignatureType(FunctionType type, - ArrayRef argAttrs, - SmallVectorImpl &convertedArgAttrs) { /*...*/ } + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine` and `Standard` dialects. + target.addLegalDialect(); - // This gets called once to set up operation converters. - llvm::DenseSet - initConverters(MLIRContext *context) override { - RewriteListBuilder::build(allocator, context); + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. + target.addIllegalDialect(); + target.addLegalOp(); + ... +} +``` + +## Conversion Patterns + +After the conversion target has been defined, we can define how to convert the +`illegal` operations into `legal` ones. Similarly to the canonicalization +framework introduced in [chapter 3](Ch-3.md), the +[`DialectConversion` framework](../../DialectConversion.md) also uses +[RewritePatterns](../../QuickstartRewrites.md) to perform the conversion logic. +These patterns may be the `RewritePatterns` seen before, or a new type of +pattern specific to the conversion framework `ConversionPattern`. +`ConversionPatterns` are different from traditional `RewritePatterns` in that +they accept an additional `operands` parameter containing operands that have +been remapped/replaced. This is used when dealing with type conversions as the +pattern will want to operand on values of the new type, but match against the +old. For our lowering, this invariant will be useful during our lowering as we +will be translating from the [TensorType](../../LangRef.md#tensor-type), +currently being operated on, to the [MemRefType](../../LangRef.md#memref-type). +Let's look at a snippet of lowering the `toy.transpose` operation: + +```c++ +/// Lower the `toy.transpose` operation to an affine loop nest. +struct TransposeOpLowering : public mlir::ConversionPattern { + TransposeOpLowering(mlir::MLIRContext *ctx) + : mlir::ConversionPattern(TransposeOp::getOperationName(), 1, ctx) {} + + /// Match and rewrite the given `toy.transpose` operation, with the given + /// operands that have been remapped from `tensor<...>` to `memref<...>`. + mlir::PatternMatchResult + matchAndRewrite(mlir::Operation *op, ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + // Call to a helper function that will lower the current operation to a set + // of affine loops. We provide a functor that operates on the remapped + // operands, as well as the loop induction variables for the inner most + // loop body. + lowerOpToLoops( + op, operands, rewriter, + [loc](mlir::PatternRewriter &rewriter, + ArrayRef memRefOperands, + ArrayRef loopIvs) { + // Generate an adaptor for the remapped operands of the TransposeOp. + // This allows for using the nice named accessors that are generated + // by the ODS. This adaptor is automatically provided by the ODS + // framework. + TransposeOpOperandAdaptor tranposeAdaptor(memRefOperands); + mlir::Value *input = tranposeAdaptor.input(); + + // Transpose the elements by generating a load from the reverse + // indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return rewriter.create(loc, input, reverseIvs); + }); + return matchSuccess(); } - -private: - llvm::BumpPtrAllocator allocator; }; ``` -Individual operation converters are following this pattern: +Now we can prepare the list of patterns to use during the lowering process: ```c++ -/// Lower a toy.add to an affine loop nest. -/// -/// This class inherit from `ConversionPattern` and override `rewrite`, -/// similarly to the PatternRewriter introduced in the previous chapter. -/// It will be called by the DialectConversion framework (see `LateLowering` -/// class below). -class AddOpConversion : public ConversionPattern { -public: - explicit AddOpConversion(MLIRContext *context) - : ConversionPattern(toy::AddOp::getOperationName(), 1, context) {} +void ToyToAffineLoweringPass::runOnFunction() { + ... - /// Lower the `op` by generating IR using the `rewriter` builder. The builder - /// is setup with a new function, the `operands` array has been populated with - /// the rewritten operands for `op` in the new function. - /// The results created by the new IR with the builder are returned, and their - /// number must match the number of result of `op`. - SmallVector rewrite(Operation *op, ArrayRef operands, - OpBuilder &rewriter) const override { - ... + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + mlir::OwningRewritePatternList patterns; + patterns.insert<..., TransposeOpLowering>(&getContext()); - // Return the newly allocated buffer, it will be used as an operand when - // converting the operations corresponding to the users of this `toy.add`. - return result; - } + ... ``` -## Linalg +## Partial Lowering -Linalg is an advanced dialect for dense algebra optimizations. It is implemented -as [a separate tutorial](../Linalg/Ch-1.md) in parallel with Toy. We are acting -as a user of this dialect by lowering Toy matrix multiplications to -`linalg.matmul`. +Once the patterns have been defined, we can perform the actual lowering. The +`DialectConversion` framework provides several different modes of lowering, but +for our purposes we will be performing a partial lowering, as we will not be +converting `toy.print` at this time. -To support this, we will split our lowering in two parts: an *early lowering* -that emits operations in the `Linalg` dialect for a subset of the Toy IR, and a -*late lowering* that materializes buffers and converts all operations and type -to the LLVM dialect. We will then be able to run specific optimizations in -between the two lowering. +```c++ +void ToyToAffineLoweringPass::runOnFunction() { + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + mlir::ConversionTarget target(getContext()); -Let's look again at our example `multiply_transpose`: + // We define the specific operations, or dialects, that are legal targets for + // this lowering. In our case, we are lowering to a combination of the + // `Affine` and `Standard` dialects. + target.addLegalDialect(); -```mlir -func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array) - attributes {toy.generic: true} { - %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array - %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array - "toy.return"(%1) : (!toy.array) -> () + // We also define the Toy dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. Given that we actually want + // a partial lowering, we explicitly mark the Toy operations that don't want + // to lower, `toy.print`, as `legal`. + target.addIllegalDialect(); + target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the Toy operations. + mlir::OwningRewritePatternList patterns; + patterns.insert<..., TransposeOpLowering>(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + auto function = getFunction(); + if (mlir::failed(mlir::applyPartialConversion(function, target, patterns))) + signalPassFailure(); } ``` -After shape inference, and lowering to `Linalg`, here is what our IR will look -like: +### Design Considerations With Partial Lowering + +Before diving into the result of our lowering, this is a good time to discuss +potential design considerations when it comes to partial lowering. In our +lowering, we will be transforming from a value-type, TensorType, to a +allocated(buffer-like) type, MemRefType. Given that we will not be lowering the +`toy.print` operation, we need to temporarily bridge these two worlds. There are +many ways to go about this, each with their own tradeoffs: + +* Generate `load` operations from the buffer + +One option is to generate `load` operations from the buffer type to materialize +an instance of the value type. This allows for the definition of the `toy.print` +operation to remain unchanged. The downside to this approach is that the +optimizations on the `affine` dialect are limited, because the `load` will +actually involve a full copy that is only visible *after* our optimizations have +been performed. + +* Generate a new version of `toy.print` that operates on the lowered type + +Another option would be to have another, lowered, variant of `toy.print` that +operates on the lowered type. The benefit of this option is that there is no +hidden, unnecessary, copy to optimizer. The downside is that another operation +definition is needed, that may duplicate many aspects of the first. Defining a +base class in [ODS](../../OpDefinitions.md) may simplify this, but you still +need to treat these operations separately. + +* Update `toy.print` to allow for operating on the lowered type + +A third option is to update the current definition of `toy.print` to allow for +operating the on the lowered type. The benefit of this approach is that it is +simple, does not introduce an additional hidden copy, and does not require +another operation definition. The downside to this option is that it requires +mixing abstraction levels in the `Toy` dialect. + +For the sake of simplicity, we will use the third option for this lowering. This +involves updating the type constraints on the PrintOp in the operation +definition file: + +```tablegen +def PrintOp : Toy_Op<"print"> { + ... + + // The print operation takes an input tensor to print. + // We also allow a F64MemRef to enable interop during partial lowering. + let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); +} +``` + +## Complete Toy Example + +Looking back at our current working example: + +```.mlir +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> + %3 = "toy.mul"(%2, %2) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64> + "toy.print"(%3) : (tensor<2x3xf64>) -> () + "toy.return"() : () -> () +} +``` + +With affine lowering added to our pipeline, we can now generate: ```mlir -func @multiply_transpose_2x3_2x3(%arg0: !toy.array<2, 3>, %arg1: !toy.array<2, 3>) -> !toy.array<2, 2> - attributes {toy.generic: false} { - %c3 = constant 3 : index +func @main() { %c0 = constant 0 : index - %c2 = constant 2 : index %c1 = constant 1 : index - %0 = "toy.transpose"(%arg1) : (!toy.array<2, 3>) -> !toy.array<3, 2> - %1 = "toy.alloc"() : () -> !toy.array<2, 2> - %2 = "toy.cast"(%1) : (!toy.array<2, 2>) -> memref<2x2xf64> - %3 = "toy.cast"(%arg0) : (!toy.array<2, 3>) -> memref<2x3xf64> - %4 = "toy.cast"(%0) : (!toy.array<3, 2>) -> memref<3x2xf64> - %5 = linalg.range %c0:%c2:%c1 : !linalg.range - %6 = linalg.range %c0:%c3:%c1 : !linalg.range - %7 = linalg.view %3[%5, %6] : !linalg<"view"> - %8 = linalg.view %4[%6, %5] : !linalg<"view"> - %9 = linalg.view %2[%5, %5] : !linalg<"view"> - linalg.matmul(%7, %8, %9) : !linalg<"view"> - "toy.return"(%1) : (!toy.array<2, 2>) -> () -} -``` + %c2 = constant 2 : index + %cst = constant 1.000000e+00 : f64 + %cst_0 = constant 2.000000e+00 : f64 + %cst_1 = constant 3.000000e+00 : f64 + %cst_2 = constant 4.000000e+00 : f64 + %cst_3 = constant 5.000000e+00 : f64 + %cst_4 = constant 6.000000e+00 : f64 -Note how the operations from multiple dialects are coexisting in this function. + // Allocating buffers for the inputs and outputs. + %0 = alloc() : memref<2x3xf64> + %1 = alloc() : memref<2x3xf64> + %2 = alloc() : memref<2x3xf64> -You can reproduce this result with `bin/toyc-ch5 -test/Examples/Toy/Ch5/lowering.toy -emit=mlir-linalg` + // Initialize the input buffer with the constant values. + affine.store %cst, %2[%c0, %c0] : memref<2x3xf64> + affine.store %cst_0, %2[%c0, %c1] : memref<2x3xf64> + affine.store %cst_1, %2[%c0, %c2] : memref<2x3xf64> + affine.store %cst_2, %2[%c1, %c0] : memref<2x3xf64> + affine.store %cst_3, %2[%c1, %c1] : memref<2x3xf64> + affine.store %cst_4, %2[%c1, %c2] : memref<2x3xf64> -## Emitting LLVM - -The availability of various dialects allows for a smooth lowering by reducing -the impedance mismatch between dialects. For example we don't need to lower our -`toy.print` over array directly to LLVM IR, we can use the well structured loop -from the `Affine` dialect for convenience when scanning the array and insert a -call to `llvm.printf` in the body. We will rely on MLIR lowering to LLVM for the -`Affine` dialect, we get it for free. Here is a simplified version of the code -in this chapter for lowering `toy.print`: - -```c++ - // Create our loop nest now - using namespace edsc; - using llvmCall = intrinsics::ValueBuilder; - ScopedContext scope(rewriter, loc); - ValueHandle zero = intrinsics::constant_index(0); - ValueHandle fmtCst(getConstantCharBuffer(rewriter, loc, "%f ")); - ValueHandle fmtEol(getConstantCharBuffer(rewriter, loc, "\n")); - MemRefView vOp(operand); - IndexedValue iOp(operand); - IndexHandle i, j, M(vOp.ub(0)), N(vOp.ub(1)); - LoopBuilder(&i, zero, M, 1)({ - LoopBuilder(&j, zero, N, 1)({ - llvmCall(retTy, - rewriter.getSymbolRefAttr(printfFunc), - {fmtCst, iOp(i, j)}) - }), - llvmCall(retTy, rewriter.getSymbolRefAttr(printfFunc), {fmtEol}) - }); -``` - -For instance the Toy IR may contain: - -``` - "toy.print"(%0) : (!toy.array<2, 2>) -> () -``` - -which the converter above will turn into this sequence: - -```mlir - affine.for %i0 = 0 to 2 { - affine.for %i1 = 0 to 2 { - %3 = load %0[%i0, %i1] : memref<2x2xf64> - %4 = llvm.call @printf(%1, %3) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32 + // Load the transpose value from the input buffer and store it into the + // next input buffer. + affine.for %arg0 = 0 to 2 { + affine.for %arg1 = 0 to 3 { + %3 = affine.load %2[%arg1, %arg0] : memref<2x3xf64> + affine.store %3, %1[%arg0, %arg1] : memref<2x3xf64> } - %5 = llvm.call @printf(%2, %cst_21) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32 } + + // Multiply and store into the output buffer. + affine.for %arg0 = 0 to 2 { + affine.for %arg1 = 0 to 3 { + %3 = affine.load %1[%arg0, %arg1] : memref<2x3xf64> + %4 = affine.load %1[%arg0, %arg1] : memref<2x3xf64> + %5 = mulf %3, %4 : f64 + affine.store %5, %0[%arg0, %arg1] : memref<2x3xf64> + } + } + + // Print the value held by the buffer. + "toy.print"(%0) : (memref<2x3xf64>) -> () + dealloc %2 : memref<2x3xf64> + dealloc %1 : memref<2x3xf64> + dealloc %0 : memref<2x3xf64> + return +} ``` -Note the mix of a loop nest in the `Affine` dialect, with an operation -`llvm.call` in the body. MLIR knows already how to lower this to: +## Taking Advantage of Affine Optimization + +Our naive lowering is correct, but it leaves a lot to be desired in regards to +efficiency; For example the lowering of `toy.mul` has generated some redundant +loads. Let's look at how adding a few existing optimizations to the pipeline can +help clean this up. Adding the `LoopFusion` and `MemRefDataFlowOpt` passes to +the pipeline gives the following result: ```mlir - llvm.br ^bb1(%87 : !llvm.i64) -^bb1(%89: !llvm.i64): // 2 preds: ^bb0, ^bb5 - %90 = llvm.icmp "slt" %89, %88 : !llvm.i64 - llvm.cond_br %90, ^bb2, ^bb6 -^bb2: // pred: ^bb1 - %91 = llvm.mlir.constant(0 : index) : !llvm.i64 - %92 = llvm.mlir.constant(2 : index) : !llvm.i64 - llvm.br ^bb3(%91 : !llvm.i64) -^bb3(%93: !llvm.i64): // 2 preds: ^bb2, ^bb4 - %94 = llvm.icmp "slt" %93, %92 : !llvm.i64 - llvm.cond_br %94, ^bb4, ^bb5 -^bb4: // pred: ^bb3 - %95 = llvm.mlir.constant(2 : index) : !llvm.i64 - %96 = llvm.mlir.constant(2 : index) : !llvm.i64 - %97 = llvm.mul %89, %96 : !llvm.i64 - %98 = llvm.add %97, %93 : !llvm.i64 - %99 = llvm.getelementptr %6[%98] : (!llvm<"double*">, !llvm.i64) -> !llvm<"double*"> - %100 = llvm.load %99 : !llvm<"double*"> - %101 = llvm.call @printf(%48, %100) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32 - %102 = llvm.mlir.constant(1 : index) : !llvm.i64 - %103 = llvm.add %93, %102 : !llvm.i64 - llvm.br ^bb3(%103 : !llvm.i64) -^bb5: // pred: ^bb3 - %104 = llvm.call @printf(%76, %71) : (!llvm<"i8*">, !llvm.double) -> !llvm.i32 - %105 = llvm.mlir.constant(1 : index) : !llvm.i64 - %106 = llvm.add %89, %105 : !llvm.i64 - llvm.br ^bb1(%106 : !llvm.i64) -``` +func @main() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %cst = constant 1.000000e+00 : f64 + %cst_0 = constant 2.000000e+00 : f64 + %cst_1 = constant 3.000000e+00 : f64 + %cst_2 = constant 4.000000e+00 : f64 + %cst_3 = constant 5.000000e+00 : f64 + %cst_4 = constant 6.000000e+00 : f64 -We appreciate the ease to generate the former, as well as the readability! + // Allocating buffers for the inputs and outputs. + %0 = alloc() : memref<2x3xf64> + %1 = alloc() : memref<2x3xf64> -You may reproduce these results with `echo "def main() { print([[1,2],[3,4]]); } -" | bin/toyc-ch5 -x toy - -emit=llvm-dialect` and `echo "def main() { -print([[1,2],[3,4]]); } " | bin/toyc-ch5 -x toy - -emit=llvm-ir`. + // Initialize the input buffer with the constant values. + affine.store %cst, %1[%c0, %c0] : memref<2x3xf64> + affine.store %cst_0, %1[%c0, %c1] : memref<2x3xf64> + affine.store %cst_1, %1[%c0, %c2] : memref<2x3xf64> + affine.store %cst_2, %1[%c1, %c0] : memref<2x3xf64> + affine.store %cst_3, %1[%c1, %c1] : memref<2x3xf64> + affine.store %cst_4, %1[%c1, %c2] : memref<2x3xf64> -# CodeGen: Getting Out of MLIR + affine.for %arg0 = 0 to 2 { + affine.for %arg1 = 0 to 3 { + // Load the transpose value from the input buffer. + %2 = affine.load %1[%arg1, %arg0] : memref<2x3xf64> -At this point, all the IR is expressed in the LLVM dialect, MLIR can perform a -straight conversion to an LLVM module. You may look into -[`Ch5/toyc.cpp`](../../../examples/toy/Ch5/toyc.cpp) for the `dumpLLVM()` -function: - -```c++ -int dumpLLVM() { - mlir::MLIRContext context; - auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true); - auto llvmModule = translateModuleToLLVMIR(*module); - if (!llvmModule) { - llvm::errs() << "Failed to emit LLVM IR\n"; - return -1; + // Multiply and store into the output buffer. + %3 = mulf %2, %2 : f64 + affine.store %3, %0[%arg0, %arg1] : memref<2x3xf64> + } } - llvm::errs() << *llvmModule << "\n"; - return 0; + + // Print the value held by the buffer. + "toy.print"(%0) : (memref<2x3xf64>) -> () + dealloc %1 : memref<2x3xf64> + dealloc %0 : memref<2x3xf64> + return } ``` -Adding a JIT isn't much more involved either: +Here we can see that an allocation was removed, the two loop nests were fused, +and we also were able to remove an unnecessary allocation! You can build +`toyc-ch5` and try yourself: `toyc-ch5 test/lowering.toy -emit=mlir-affine`. We +can also check our optimizations by adding `-opt`. -```c++ -int runJit() { - mlir::MLIRContext context; - auto module = loadFileAndProcessModule(context, /* EnableLowering=*/ true); - - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - // Create an MLIR execution engine. Note that it takes a null pass manager - // to make sure it won't run "default" passes on the MLIR that would trigger - // a second conversion to LLVM IR. The execution engine eagerly JIT-compiles - // the module. - auto maybeEngine = - mlir::ExecutionEngine::create(module.get(), /*pm=*/nullptr); - assert(maybeEngine && "failed to construct an execution engine"); - auto &engine = maybeEngine.get(); - - // Invoke the JIT-compiled function with the arguments. Note that, for API - // uniformity reasons, it takes a list of type-erased pointers to arguments. - auto invocationResult = engine->invoke("main"); - if(invocationResult) { - llvm::errs() << "JIT invocation failed\n"; - return -1; - } - - return 0; -} -``` - -You can play with it, from the build directory: - -```bash -$ echo 'def main() { print([[1, 2], [3, 4]]); }' | ./bin/toyc-ch5 -emit=jit -1.000000 2.000000 -3.000000 4.000000 -``` - -You can also play with `-emit=mlir`, `-emit=mlir-linalg`, `-emit=llvm-dialect`, -and `-emit=llvm-ir` to compare the various level of IR involved. Try also -options like `--print-ir-after-all` to track the evolution of the IR throughout -the pipeline. +In this chapter we explored some aspects of partial lowering, with the intent to +optimize. In the [next chapter](Ch-6.md) we will continue the discussion about +dialect conversion by targeting LLVM for code generation. diff --git a/mlir/test/Examples/Toy/Ch4/ast.toy b/mlir/test/Examples/Toy/Ch4/ast.toy index 9576c9c5ced0..c24b7b94cbc7 100644 --- a/mlir/test/Examples/Toy/Ch4/ast.toy +++ b/mlir/test/Examples/Toy/Ch4/ast.toy @@ -1,7 +1,6 @@ # RUN: toyc-ch4 %s -emit=ast 2>&1 | FileCheck %s - -# User defined generic function that operates solely on +# User defined generic function that operates on unknown shaped arguments. def multiply_transpose(a, b) { return a * transpose(b); } @@ -11,8 +10,10 @@ def main() { # The shape is inferred from the supplied literal. var a = [[1, 2, 3], [4, 5, 6]]; # b is identical to a, the literal array is implicitly reshaped: defining new - # variables is the way to reshape arrays (element count must match). + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). var b<2, 3> = [1, 2, 3, 4, 5, 6]; + # This call will specialize `multiply_transpose` with <2, 3> for both # arguments and deduce a return type of <2, 2> in initialization of `c`. var c = multiply_transpose(a, b); @@ -30,44 +31,44 @@ def main() { # CHECK: Module: # CHECK-NEXT: Function -# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:5:1' +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1' # CHECK-NEXT: Params: [a, b] # CHECK-NEXT: Block { # CHECK-NEXT: Retur -# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:6:14 -# CHECK-NEXT: var: a @{{.*}}ast.toy:6:10 -# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:6:14 -# CHECK-NEXT: var: b @{{.*}}ast.toy:6:24 +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:14 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:10 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:14 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:24 # CHECK-NEXT: ] # CHECK-NEXT: } // Block # CHECK-NEXT: Function -# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:9:1' +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1' # CHECK-NEXT: Params: [] # CHECK-NEXT: Block { -# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:12:3 -# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:12:11 +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 # CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 # CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 -# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:18:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:18:11 -# CHECK-NEXT: var: a @{{.*}}ast.toy:18:30 -# CHECK-NEXT: var: b @{{.*}}ast.toy:18:33 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:21:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:21:11 -# CHECK-NEXT: var: b @{{.*}}ast.toy:21:30 -# CHECK-NEXT: var: a @{{.*}}ast.toy:21:33 +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:24:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:24:11 -# CHECK-NEXT: var: b @{{.*}}ast.toy:24:30 -# CHECK-NEXT: var: c @{{.*}}ast.toy:24:33 +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:27:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:27:11 -# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:27:30 -# CHECK-NEXT: var: a @{{.*}}ast.toy:27:40 +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 # CHECK-NEXT: ] -# CHECK-NEXT: var: c @{{.*}}ast.toy:27:44 +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 # CHECK-NEXT: ] diff --git a/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir new file mode 100644 index 000000000000..74617f8fbc22 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch5/affine-lowering.mlir @@ -0,0 +1,65 @@ +// RUN: toyc-ch5 %s -emit=mlir-affine 2>&1 | FileCheck %s +// RUN: toyc-ch5 %s -emit=mlir-affine -opt 2>&1 | FileCheck %s --check-prefix=OPT + +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %2 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> + %3 = "toy.mul"(%2, %2) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<2x3xf64> + "toy.print"(%3) : (tensor<2x3xf64>) -> () + "toy.return"() : () -> () +} + +// CHECK-LABEL: func @main() +// CHECK: [[VAL_0:%.*]] = constant 1.000000e+00 : f64 +// CHECK: [[VAL_1:%.*]] = constant 2.000000e+00 : f64 +// CHECK: [[VAL_2:%.*]] = constant 3.000000e+00 : f64 +// CHECK: [[VAL_3:%.*]] = constant 4.000000e+00 : f64 +// CHECK: [[VAL_4:%.*]] = constant 5.000000e+00 : f64 +// CHECK: [[VAL_5:%.*]] = constant 6.000000e+00 : f64 +// CHECK: [[VAL_6:%.*]] = alloc() : memref<2x3xf64> +// CHECK: [[VAL_7:%.*]] = alloc() : memref<2x3xf64> +// CHECK: [[VAL_8:%.*]] = alloc() : memref<2x3xf64> +// CHECK: affine.store [[VAL_0]], [[VAL_8]][0, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_1]], [[VAL_8]][0, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_2]], [[VAL_8]][0, 2] : memref<2x3xf64> +// CHECK: affine.store [[VAL_3]], [[VAL_8]][1, 0] : memref<2x3xf64> +// CHECK: affine.store [[VAL_4]], [[VAL_8]][1, 1] : memref<2x3xf64> +// CHECK: affine.store [[VAL_5]], [[VAL_8]][1, 2] : memref<2x3xf64> +// CHECK: affine.for [[VAL_9:%.*]] = 0 to 2 { +// CHECK: affine.for [[VAL_10:%.*]] = 0 to 3 { +// CHECK: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64> +// CHECK: affine.store [[VAL_11]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<2x3xf64> +// CHECK: affine.for [[VAL_12:%.*]] = 0 to 2 { +// CHECK: affine.for [[VAL_13:%.*]] = 0 to 3 { +// CHECK: [[VAL_14:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64> +// CHECK: [[VAL_15:%.*]] = affine.load [[VAL_7]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64> +// CHECK: [[VAL_16:%.*]] = mulf [[VAL_14]], [[VAL_15]] : f64 +// CHECK: affine.store [[VAL_16]], [[VAL_6]]{{\[}}[[VAL_12]], [[VAL_13]]] : memref<2x3xf64> +// CHECK: "toy.print"([[VAL_6]]) : (memref<2x3xf64>) -> () +// CHECK: dealloc [[VAL_8]] : memref<2x3xf64> +// CHECK: dealloc [[VAL_7]] : memref<2x3xf64> +// CHECK: dealloc [[VAL_6]] : memref<2x3xf64> + +// OPT-LABEL: func @main() +// OPT: [[VAL_1:%.*]] = constant 1.000000e+00 : f64 +// OPT: [[VAL_2:%.*]] = constant 2.000000e+00 : f64 +// OPT: [[VAL_3:%.*]] = constant 3.000000e+00 : f64 +// OPT: [[VAL_4:%.*]] = constant 4.000000e+00 : f64 +// OPT: [[VAL_5:%.*]] = constant 5.000000e+00 : f64 +// OPT: [[VAL_6:%.*]] = constant 6.000000e+00 : f64 +// OPT: [[VAL_7:%.*]] = alloc() : memref<2x3xf64> +// OPT: [[VAL_8:%.*]] = alloc() : memref<2x3xf64> +// OPT: affine.store [[VAL_1]], [[VAL_8]]{{\[}}0, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_2]], [[VAL_8]]{{\[}}0, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_3]], [[VAL_8]]{{\[}}0, 2] : memref<2x3xf64> +// OPT: affine.store [[VAL_4]], [[VAL_8]]{{\[}}1, 0] : memref<2x3xf64> +// OPT: affine.store [[VAL_5]], [[VAL_8]]{{\[}}1, 1] : memref<2x3xf64> +// OPT: affine.store [[VAL_6]], [[VAL_8]]{{\[}}1, 2] : memref<2x3xf64> +// OPT: affine.for [[VAL_9:%.*]] = 0 to 2 { +// OPT: affine.for [[VAL_10:%.*]] = 0 to 3 { +// OPT: [[VAL_11:%.*]] = affine.load [[VAL_8]]{{\[}}[[VAL_10]], [[VAL_9]]] : memref<2x3xf64> +// OPT: [[VAL_12:%.*]] = mulf [[VAL_11]], [[VAL_11]] : f64 +// OPT: affine.store [[VAL_12]], [[VAL_7]]{{\[}}[[VAL_9]], [[VAL_10]]] : memref<2x3xf64> +// OPT: "toy.print"([[VAL_7]]) : (memref<2x3xf64>) -> () +// OPT: dealloc [[VAL_8]] : memref<2x3xf64> +// OPT: dealloc [[VAL_7]] : memref<2x3xf64> diff --git a/mlir/test/Examples/Toy/Ch5/ast.toy b/mlir/test/Examples/Toy/Ch5/ast.toy index 96761165cf8b..5a4ecbbce532 100644 --- a/mlir/test/Examples/Toy/Ch5/ast.toy +++ b/mlir/test/Examples/Toy/Ch5/ast.toy @@ -1,7 +1,6 @@ # RUN: toyc-ch5 %s -emit=ast 2>&1 | FileCheck %s - -# User defined generic function that operates solely on +# User defined generic function that operates on unknown shaped arguments. def multiply_transpose(a, b) { return a * transpose(b); } @@ -10,9 +9,11 @@ def main() { # Define a variable `a` with shape <2, 3>, initialized with the literal value. # The shape is inferred from the supplied literal. var a = [[1, 2, 3], [4, 5, 6]]; - # b is identical to a, the literal array is implicitely reshaped: defining new - # variables is the way to reshape arrays (element count must match). + # b is identical to a, the literal array is implicitly reshaped: defining new + # variables is the way to reshape arrays (element count in literal must match + # the size of specified shape). var b<2, 3> = [1, 2, 3, 4, 5, 6]; + # This call will specialize `multiply_transpose` with <2, 3> for both # arguments and deduce a return type of <2, 2> in initialization of `c`. var c = multiply_transpose(a, b); @@ -30,44 +31,44 @@ def main() { # CHECK: Module: # CHECK-NEXT: Function -# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:5:1' +# CHECK-NEXT: Proto 'multiply_transpose' @{{.*}}ast.toy:4:1' # CHECK-NEXT: Params: [a, b] # CHECK-NEXT: Block { # CHECK-NEXT: Retur -# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:6:14 -# CHECK-NEXT: var: a @{{.*}}ast.toy:6:10 -# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:6:14 -# CHECK-NEXT: var: b @{{.*}}ast.toy:6:24 +# CHECK-NEXT: BinOp: * @{{.*}}ast.toy:5:14 +# CHECK-NEXT: var: a @{{.*}}ast.toy:5:10 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:5:14 +# CHECK-NEXT: var: b @{{.*}}ast.toy:5:24 # CHECK-NEXT: ] # CHECK-NEXT: } // Block # CHECK-NEXT: Function -# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:9:1' +# CHECK-NEXT: Proto 'main' @{{.*}}ast.toy:8:1' # CHECK-NEXT: Params: [] # CHECK-NEXT: Block { -# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:12:3 -# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:12:11 +# CHECK-NEXT: VarDecl a<> @{{.*}}ast.toy:11:3 +# CHECK-NEXT: Literal: <2, 3>[ <3>[ 1.000000e+00, 2.000000e+00, 3.000000e+00], <3>[ 4.000000e+00, 5.000000e+00, 6.000000e+00]] @{{.*}}ast.toy:11:11 # CHECK-NEXT: VarDecl b<2, 3> @{{.*}}ast.toy:15:3 # CHECK-NEXT: Literal: <6>[ 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00] @{{.*}}ast.toy:15:17 -# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:18:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:18:11 -# CHECK-NEXT: var: a @{{.*}}ast.toy:18:30 -# CHECK-NEXT: var: b @{{.*}}ast.toy:18:33 +# CHECK-NEXT: VarDecl c<> @{{.*}}ast.toy:19:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:19:11 +# CHECK-NEXT: var: a @{{.*}}ast.toy:19:30 +# CHECK-NEXT: var: b @{{.*}}ast.toy:19:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:21:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:21:11 -# CHECK-NEXT: var: b @{{.*}}ast.toy:21:30 -# CHECK-NEXT: var: a @{{.*}}ast.toy:21:33 +# CHECK-NEXT: VarDecl d<> @{{.*}}ast.toy:22:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:22:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:22:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:22:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:24:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:24:11 -# CHECK-NEXT: var: b @{{.*}}ast.toy:24:30 -# CHECK-NEXT: var: c @{{.*}}ast.toy:24:33 +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:25:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:25:11 +# CHECK-NEXT: var: b @{{.*}}ast.toy:25:30 +# CHECK-NEXT: var: c @{{.*}}ast.toy:25:33 # CHECK-NEXT: ] -# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:27:3 -# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:27:11 -# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:27:30 -# CHECK-NEXT: var: a @{{.*}}ast.toy:27:40 +# CHECK-NEXT: VarDecl e<> @{{.*}}ast.toy:28:3 +# CHECK-NEXT: Call 'multiply_transpose' [ @{{.*}}ast.toy:28:11 +# CHECK-NEXT: Call 'transpose' [ @{{.*}}ast.toy:28:30 +# CHECK-NEXT: var: a @{{.*}}ast.toy:28:40 # CHECK-NEXT: ] -# CHECK-NEXT: var: c @{{.*}}ast.toy:27:44 +# CHECK-NEXT: var: c @{{.*}}ast.toy:28:44 # CHECK-NEXT: ] diff --git a/mlir/test/Examples/Toy/Ch5/codegen.toy b/mlir/test/Examples/Toy/Ch5/codegen.toy index 5008de2b92aa..50a1d8ff0893 100644 --- a/mlir/test/Examples/Toy/Ch5/codegen.toy +++ b/mlir/test/Examples/Toy/Ch5/codegen.toy @@ -13,20 +13,19 @@ def main() { print(d); } -# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array) -# CHECK-NEXT: attributes {toy.generic = true} { -# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array -# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array -# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> () -# CHECK-NEXT: } - -# CHECK-LABEL: func @main() { -# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3> -# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3> -# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6> -# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3> -# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array -# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array -# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> () -# CHECK-NEXT: "toy.return"() : () -> () +# CHECK-LABEL: func @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) +# CHECK-NEXT: attributes {toy.generic} { +# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> () +# CHECK-LABEL: func @main() { +# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> () +# CHECK-NEXT: "toy.return"() : () -> () diff --git a/mlir/test/Examples/Toy/Ch5/invalid.mlir b/mlir/test/Examples/Toy/Ch5/invalid.mlir index df8e2dfde786..73ade9a910eb 100644 --- a/mlir/test/Examples/Toy/Ch5/invalid.mlir +++ b/mlir/test/Examples/Toy/Ch5/invalid.mlir @@ -1,11 +1,9 @@ // RUN: not toyc-ch5 %s -emit=mlir 2>&1 - -// This IR is not "valid": +// The following IR is not "valid": // - toy.print should not return a value. // - toy.print should take an argument. // - There should be a block terminator. -// This all round-trip since this is opaque for MLIR. func @main() { - %0 = "toy.print"() : () -> !toy.array<2, 3> + %0 = "toy.print"() : () -> tensor<2x3xf64> } diff --git a/mlir/test/Examples/Toy/Ch5/lowering.toy b/mlir/test/Examples/Toy/Ch5/lowering.toy deleted file mode 100644 index 6f16437e9011..000000000000 --- a/mlir/test/Examples/Toy/Ch5/lowering.toy +++ /dev/null @@ -1,16 +0,0 @@ -# RUN: toyc-ch5 %s -emit=llvm-ir 2>&1 | FileCheck %s - -# User defined generic function that operates on unknown shaped arguments -def multiply_transpose(a, b) { - return a * transpose(b); -} - -# CHECK: define void @main() { -# CHECK: %1 = call i8* @malloc(i64 mul (i64 ptrtoint (double* getelementptr (double, double* null, i64 1) to i64), i64 6)) -def main() { - var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; - var b<2, 3> = [1, 2, 3, 4, 5, 6]; - var c = multiply_transpose(a, b); - var d = multiply_transpose(b, a); - print(d); -} diff --git a/mlir/test/Examples/Toy/Ch5/scalar.toy b/mlir/test/Examples/Toy/Ch5/scalar.toy index 2e8d46f0d8ae..2743b5a3ac94 100644 --- a/mlir/test/Examples/Toy/Ch5/scalar.toy +++ b/mlir/test/Examples/Toy/Ch5/scalar.toy @@ -6,9 +6,9 @@ def main() { } # CHECK-LABEL: func @main() { -# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1> -# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2> -# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> () +# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor} : () -> tensor +# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor) -> tensor<2x2xf64> +# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> () # CHECK-NEXT: "toy.return"() : () -> () # CHECK-NEXT: } diff --git a/mlir/test/Examples/Toy/Ch5/shape_inference.mlir b/mlir/test/Examples/Toy/Ch5/shape_inference.mlir new file mode 100644 index 000000000000..7be7eda77822 --- /dev/null +++ b/mlir/test/Examples/Toy/Ch5/shape_inference.mlir @@ -0,0 +1,30 @@ +// RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s + +// Check the result of inlining+shape inference on an input module. + +func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { + %0 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64> + %1 = "toy.mul"(%arg0, %0) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "toy.return"(%1) : (tensor<*xf64>) -> () +} +func @main() { + %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> + %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64> + %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> + %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64> + %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + "toy.print"(%5) : (tensor<*xf64>) -> () + "toy.return"() : () -> () +} + +// CHECK-NOT: func @multiply_transpose +// CHECK-NOT: tensor<*xf64> + +// CHECK-LABEL: func @main() { +// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +// CHECK: [[VAL_1:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +// CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64> +// CHECK: [[VAL_3:%.*]] = "toy.mul"([[VAL_1]], [[VAL_2]]) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<2x2xf64> +// CHECK: "toy.print"([[VAL_3]]) : (tensor<2x2xf64>) -> () +// CHECK: "toy.return"() : () -> () diff --git a/mlir/test/Examples/Toy/Ch5/transpose_transpose.toy b/mlir/test/Examples/Toy/Ch5/transpose_transpose.toy deleted file mode 100644 index 109cbd82c89f..000000000000 --- a/mlir/test/Examples/Toy/Ch5/transpose_transpose.toy +++ /dev/null @@ -1,19 +0,0 @@ -# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s -# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT - -def transpose_transpose(x) { - return transpose(transpose(x)); -} - -def main() { - print(transpose_transpose([[1, 2], [3, 4]])); -} - -#CHECK-LABEL: func @transpose_transpose -#CHECK: transpose -#CHECK-LABEL: main - - -#OPT-LABEL: func @transpose_transpose -#OPT-NOT: transpose - diff --git a/mlir/test/Examples/Toy/Ch5/trivialReshape.toy b/mlir/test/Examples/Toy/Ch5/trivialReshape.toy deleted file mode 100644 index cb9946d8fa4b..000000000000 --- a/mlir/test/Examples/Toy/Ch5/trivialReshape.toy +++ /dev/null @@ -1,24 +0,0 @@ -# RUN: toyc-ch5 %s -emit=mlir 2>&1 | FileCheck %s -# RUN: toyc-ch5 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT - -# We expect no reshape in this function with optimizations enabled -def foo(a) { - var b<2,1> = a; - var c<2,1> = b; - print(c); -} - -def main() { - var a<2, 1> = [1, 2]; - foo(a); -} - -# without optimizations, match the reshape -#CHECK-LABEL: func @foo -#CHECK: reshape -#CHECK-LABEL: main - -# with optimizations, ensure no reshape -#OPT-LABEL: main -#OPT-LABEL: func @foo_2x1 -#OPT-NOT: reshape