forked from OSchip/llvm-project

Update Chapter 4 of the Toy tutorial

This Chapter now introduces and makes use of the Interface concept in MLIR to demonstrate ShapeInference.

Closes tensorflow/mlir#191
PiperOrigin-RevId: 275085151

parent e88dbc8c95
commit 3940b90d84
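The flow this change introduces is: ops declare an `inferShapes()` hook through the new op interface, and a pass walks each function and invokes the hook on any op that implements it. A minimal sketch of that consumer side, assuming the `ShapeInference` class generated from the `ShapeInferenceInterface.td` added below (the pass body here is illustrative, not the exact `ShapeInferencePass.cpp` from this commit, which is not part of this excerpt):

```cpp
#include "mlir/IR/Function.h"
#include "mlir/Pass/Pass.h"
// The header wrapping the generated ShapeInferenceOpInterfaces.h.inc is an
// assumption; the .inc itself is produced by the tablegen rules added below.
#include "toy/ShapeInferenceInterface.h"

namespace {
/// Walk every operation in the function; any op implementing the
/// ShapeInference interface is asked to recompute its result type.
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public:
  void runOnFunction() override {
    getFunction().walk([](mlir::Operation *op) {
      if (auto shapeOp = llvm::dyn_cast<mlir::toy::ShapeInference>(op))
        shapeOp.inferShapes();
    });
  }
};
} // namespace
```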
@@ -1,16 +1,29 @@
add_subdirectory(include)

set(LLVM_LINK_COMPONENTS
  Support
  )

set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
add_public_tablegen_target(ToyCh4CombineIncGen)

add_toy_chapter(toyc-ch4
  toyc.cpp
  parser/AST.cpp
  mlir/MLIRGen.cpp
-  mlir/ToyDialect.cpp
+  mlir/Dialect.cpp
  mlir/DeadFunctionEliminationPass.cpp
  mlir/ShapeInferencePass.cpp
  mlir/ToyCombine.cpp
  )

add_dependencies(toyc-ch4 ToyCh4OpsIncGen)
add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen)
add_dependencies(toyc-ch4 ToyCh4CombineIncGen)
include_directories(include/)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch4
  PRIVATE
    MLIRAnalysis
@@ -0,0 +1 @@
add_subdirectory(toy)
@@ -33,10 +33,9 @@
namespace toy {

-/// A variable
+/// A variable type with shape information.
struct VarType {
  enum { TY_FLOAT, TY_INT } elt_ty;
-  std::vector<int> shape;
+  std::vector<int64_t> shape;
};

/// Base class for all expression nodes.
@@ -50,9 +49,7 @@ public:
    Expr_Var,
    Expr_BinOp,
    Expr_Call,
-    Expr_Print, // builtin
-    Expr_If,
-    Expr_For,
+    Expr_Print,
  };

  ExprAST(ExprASTKind kind, Location location)
@@ -85,7 +82,7 @@ public:
  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
};

-///
+/// Expression class for a literal value.
class LiteralExprAST : public ExprAST {
  std::vector<std::unique_ptr<ExprAST>> values;
  std::vector<int64_t> dims;
@@ -116,7 +113,7 @@ public:
  static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};

-///
+/// Expression class for defining a variable.
class VarDeclExprAST : public ExprAST {
  std::string name;
  VarType type;
@@ -136,7 +133,7 @@ public:
  static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
};

-///
+/// Expression class for a return operator.
class ReturnExprAST : public ExprAST {
  llvm::Optional<std::unique_ptr<ExprAST>> expr;
@@ -0,0 +1,9 @@
set(LLVM_TARGET_DEFINITIONS Ops.td)
mlir_tablegen(Ops.h.inc -gen-op-decls "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
mlir_tablegen(Ops.cpp.inc -gen-op-defs "-I${CMAKE_CURRENT_SOURCE_DIR}/..")
add_public_tablegen_target(ToyCh4OpsIncGen)

set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td)
mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(ToyCh4ShapeInferenceInterfaceIncGen)
@@ -16,7 +16,7 @@
// =============================================================================
//
// This file implements the IR Dialect for the Toy language.
-// See g3doc/Tutorials/Toy/Ch-3.md for more information.
+// See g3doc/Tutorials/Toy/Ch-2.md for more information.
//
//===----------------------------------------------------------------------===//
@@ -25,325 +25,30 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/TypeSupport.h"
-#include "mlir/IR/Types.h"
+#include "mlir/IR/StandardTypes.h"

namespace mlir {
-class Builder;
-}
-
namespace toy {

/// This is the definition of the Toy dialect. A dialect inherits from
-/// mlir::Dialect and register custom operations and types (in its constructor).
-/// It can also overriding general behavior of dialects exposed as virtual
-/// method, for example regarding verification and parsing/printing.
+/// mlir::Dialect and registers custom attributes, operations, and types (in its
+/// constructor). It can also override some general behavior exposed via virtual
+/// methods.
class ToyDialect : public mlir::Dialect {
public:
  explicit ToyDialect(mlir::MLIRContext *ctx);

-  /// Parse a type registered to this dialect. Overriding this method is
-  /// required for dialects that have custom types.
-  /// Technically this is only needed to be able to round-trip to textual IR.
-  mlir::Type parseType(llvm::StringRef tyData,
-                       mlir::Location loc) const override;
-
-  /// Print a type registered to this dialect. Overriding this method is
-  /// only required for dialects that have custom types.
-  /// Technically this is only needed to be able to round-trip to textual IR.
-  void printType(mlir::Type type, llvm::raw_ostream &os) const override;
+  /// Provide a utility accessor to the dialect namespace. This is used by
+  /// several utilities for casting between dialects.
+  static llvm::StringRef getDialectNamespace() { return "toy"; }
};

-////////////////////////////////////////////////////////////////////////////////
-/////////////////////// Custom Types for the Dialect ///////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-namespace detail {
-struct ToyArrayTypeStorage;
-}
-
-/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
-enum ToyTypeKind {
-  // The enum starts at the range reserved for this dialect.
-  TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
-  TOY_ARRAY,
-};
-
-/// Type for Toy arrays.
-/// In MLIR Types are reference to immutable and uniqued objects owned by the
-/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
-/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
-/// provides the public facade API to interact with the type.
-class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
-                                                 detail::ToyArrayTypeStorage> {
-public:
-  using Base::Base;
-
-  /// Returns the dimensions for this array, or and empty range for a generic
-  /// array.
-  llvm::ArrayRef<int64_t> getShape();
-
-  /// Predicate to test if this array is generic (shape haven't been inferred
-  /// yet).
-  bool isGeneric() { return getShape().empty(); }
-
-  /// Return the rank of this array (0 if it is generic).
-  int getRank() { return getShape().size(); }
-
-  /// Return the type of individual elements in the array.
-  mlir::Type getElementType();
-
-  /// Get the unique instance of this Type from the context.
-  /// A ToyArrayType is only defined by the shape of the array.
-  static ToyArrayType get(mlir::MLIRContext *context,
-                          llvm::ArrayRef<int64_t> shape = {});
-
-  /// Support method to enable LLVM-style RTTI type casting.
-  static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
-};
-
-////////////////////////////////////////////////////////////////////////////////
-//////////////////// Custom Operations for the Dialect /////////////////////////
-////////////////////////////////////////////////////////////////////////////////
-
-/// Constant operation turns a literal into an SSA value. The data is attached
-/// to the operation as an attribute. For example:
-///
-///   %0 = "toy.constant"()
-///       {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
-///     : () -> !toy.array<2, 3>
-///
-/// An operation inherits from `class Op` and specifies optional traits. Here we
-/// indicate that `toy.constant` does not have any operands and returns a single
-/// result. The traits provide some utilities methods for the operation, for
-/// instance we will be able to use `getResult()`, but `getOperand()` won't be
-/// available.
-class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
-                                   mlir::OpTrait::OneResult,
-                                   mlir::OpTrait::HasNoSideEffect> {
-public:
-  /// This is the name used by MLIR to match an operation to this class during
-  /// parsing.
-  static llvm::StringRef getOperationName() { return "toy.constant"; }
-
-  /// The operation can have extra verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<PrintOp>(...)
-  /// This method populates the `state` that MLIR uses to create operations.
-  /// The `toy.constant` operation does not have arguments but attaches a
-  /// constant array as an attribute and returns it as an SSA value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    llvm::ArrayRef<int64_t> shape,
-                    mlir::DenseElementsAttr value);
-
-  /// Similar to the one above, but takes a single float and returns a
-  /// !toy.array<1>.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::FloatAttr value);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Generic calls represent calls to a user defined function that needs to
-/// be specialized for the shape of its arguments. The callee name is attached
-/// as a literal string as an attribute. The arguments list must match the
-/// arguments expected by the callee. For example:
-///
-///   %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
-///          : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
-///
-/// This is only valid if a function named "my_func" exists and takes two
-/// arguments.
-class GenericCallOp
-    : public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
-                      mlir::OpTrait::OneResult> {
-public:
-  /// MLIR will use this to register the operation with the parser/printer.
-  static llvm::StringRef getOperationName() { return "toy.generic_call"; }
-
-  /// Operations can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to the builder to allow:
-  ///   mlir::Builder::create<GenericCallOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.generic_call` operation accepts a callee name and a list of
-  /// arguments for the call.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    llvm::StringRef callee,
-                    llvm::ArrayRef<mlir::Value *> arguments);
-
-  /// Return the name of the callee.
-  llvm::StringRef getCalleeName();
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Return operations terminate blocks (and functions as well). They take a
-/// single argument and the type must match the function return type.
-class ReturnOp
-    : public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
-                      mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.return"; }
-
-  /// Operations can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<PrintOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.return` operation accepts an optional single array as an argument
-  /// and does not have any returned value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value = nullptr);
-
-  /// Return true if there is a returned value.
-  bool hasOperand() { return 0 != getNumOperands(); }
-
-  /// Helper to return the optional operand. Caller must check if the operand
-  /// is present before calling this.
-  mlir::Value *getOperand() { return getOperation()->getOperand(0); }
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// The print builtin takes a single array argument and does not return any.
-class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
-                                mlir::OpTrait::ZeroResult> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.print"; }
-
-  /// Operations can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<PrintOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.print` operation accepts a single array as argument and does
-  /// not have any returned value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
-                                    mlir::OpTrait::OneResult,
-                                    mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.transpose"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<TransposeOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.transpose` operation accepts a single array as argument and
-  /// returns the transposed array as its only result.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value);
-
-  // Register our patterns for rewrite by the Canonicalization framework.
-  static void
-  getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
-                              mlir::MLIRContext *context);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Reshape operation is transforming its input array into a new array with the
-/// same number of elements but different shapes. For example:
-///
-///    %0 = "toy.reshape"(%arg1) : (!toy.array<10>) -> !toy.array<5, 2>
-///
-class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
-                                  mlir::OpTrait::OneResult,
-                                  mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.reshape"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<ReshapeOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.reshape` operation accepts a single array as argument and
-  /// returns the array with the specified reshapedType as its only result.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *value, ToyArrayType reshapedType);
-
-  // Register our patterns for rewrite by the Canonicalization framework.
-  static void
-  getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
-                              mlir::MLIRContext *context);
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Binary operation implementing a multiplication. For two-dimensional array
-/// a matrix multiplication is implemented, while for one dimensional array a
-/// dot product is performed.
-class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
-                              mlir::OpTrait::OneResult,
-                              mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.mul"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<PrintOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.mul` operation accepts two operands as argument and returns
-  /// a single value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *lhs, mlir::Value *rhs);
-
-  /// Convenience accessor for LHS of the expression.
-  mlir::Value *getLHS() { return getOperand(0); }
-
-  /// Convenience accessor for RHS of the expression.
-  mlir::Value *getRHS() { return getOperand(1); }
-
-  /// Inherit constructor.
-  using Op::Op;
-};
-
-/// Element wise addition of two arrays. The shape must match.
-class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
-                              mlir::OpTrait::OneResult,
-                              mlir::OpTrait::HasNoSideEffect> {
-public:
-  static llvm::StringRef getOperationName() { return "toy.add"; }
-
-  /// Operation can add custom verification beyond the traits they define.
-  mlir::LogicalResult verify();
-
-  /// Interface to mlir::Builder::create<PrintOp>(...)
-  /// This method populate the `state` that MLIR use to create operations.
-  /// The `toy.mul` operation accepts two operands as argument and returns
-  /// a single value.
-  static void build(mlir::Builder *builder, mlir::OperationState &state,
-                    mlir::Value *lhs, mlir::Value *rhs);
-
-  /// Convenience accessor for LHS of the expression.
-  mlir::Value *getLHS() { return getOperand(0); }
-
-  /// Convenience accessor for RHS of the expression.
-  mlir::Value *getRHS() { return getOperand(1); }
-
-  /// Inherit constructor.
-  using Op::Op;
-};
+/// Include the auto-generated header file containing the declarations of the
+/// toy operations.
+#define GET_OP_CLASSES
+#include "toy/Ops.h.inc"

} // end namespace toy
+} // end namespace mlir

#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
@@ -31,7 +31,7 @@ namespace toy {

/// Structure definition of a location in a file.
struct Location {
-  std::shared_ptr<std::string> file; ///< filename
+  std::shared_ptr<std::string> file; ///< filename.
  int line;                          ///< line number.
  int col;                           ///< column number.
};
@@ -0,0 +1,285 @@
//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines the operations of the Toy dialect.
//
//===----------------------------------------------------------------------===//

#ifdef TOY_OPS
#else
#define TOY_OPS

#ifdef SHAPE_INFERENCE_INTERFACE
#else
include "toy/ShapeInferenceInterface.td"
#endif // SHAPE_INFERENCE_INTERFACE

// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
  let name = "toy";
  let cppNamespace = "toy";
}

// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
//   * The parent dialect of the operation.
//   * The mnemonic for the operation, or the name without the dialect prefix.
//   * A list of traits for the operation.
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<Toy_Dialect, mnemonic, traits>;

//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//

// We define a toy operation by inheriting from our base 'Toy_Op' class above.
// Here we provide the mnemonic and a list of traits for the operation. The
// constant operation is marked as 'NoSideEffect' as it is a pure operation
// and may be removed if dead.
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
  // Provide a summary and description for this operation. This can be used to
  // auto-generate documentation of the operations within our dialect.
  let summary = "constant";
  let description = [{
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

    ```mlir
      %0 = "toy.constant"()
         { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
        : () -> tensor<2x3xf64>
    ```
  }];

  // The constant operation takes an attribute as the only input.
  let arguments = (ins F64ElementsAttr:$value);

  // The constant operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Add custom build methods for the constant operation. These methods populate
  // the `state` that MLIR uses to create operations, i.e. these are used when
  // using `builder.create<ConstantOp>(...)`.
  let builders = [
    // Build a constant with a given constant tensor value.
    OpBuilder<"Builder *builder, OperationState &result, "
              "DenseElementsAttr value", [{
      build(builder, result, value.getType(), value);
    }]>,

    // Build a constant with a given constant floating-point value.
    OpBuilder<"Builder *builder, OperationState &result, double value", [{
      buildConstantOp(builder, result, value);
    }]>
  ];

  // Invoke a static verify method to verify this constant operation.
  let verifier = [{ return ::verify(*this); }];
}

def AddOp : Toy_Op<"add", [NoSideEffect]> {
  let summary = "element-wise addition operation";
  let description = [{
    The "add" operation performs element-wise addition between two tensors.
    The shapes of the tensor operands are expected to match.
  }];

  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Allow building an AddOp from the two input operands.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
      buildAddOp(b, result, lhs, rhs);
    }]
  >];
  let extraClassDeclaration = [{
    void inferShapes() {
      getResult()->setType(getOperand(0)->getType());
      return;
    }
  }];
}

def GenericCallOp : Toy_Op<"generic_call"> {
  let summary = "generic call operation";
  let description = [{
    Generic calls represent calls to a user defined function that needs to
    be specialized for the shape of its arguments. The callee name is attached
    as a symbol reference via an attribute. The arguments list must match the
    arguments expected by the callee. For example:

    ```mlir
     %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
           : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
    ```

    This is only valid if a function named "my_func" exists and takes two
    arguments.
  }];

  // The generic call operation takes a symbol reference attribute as the
  // callee, and inputs for the call.
  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

  // The generic call operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Add custom build methods for the generic call operation.
  let builders = [
    // Build a generic call with the given callee name and arguments.
    OpBuilder<"Builder *builder, OperationState &result, "
              "StringRef callee, ArrayRef<Value *> arguments", [{
      buildGenericCallOp(builder, result, callee, arguments);
    }]>
  ];
}

def MulOp : Toy_Op<"mul", [NoSideEffect]> {
  let summary = "element-wise multiplication operation";
  let description = [{
    The "mul" operation performs element-wise multiplication between two
    tensors. The shapes of the tensor operands are expected to match.
  }];

  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Allow building a MulOp from the two input operands.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
      buildMulOp(b, result, lhs, rhs);
    }]
  >];
  let extraClassDeclaration = [{
    void inferShapes() {
      auto lhs = getOperand(0)->getType().cast<RankedTensorType>();
      auto rhs = getOperand(1)->getType().cast<RankedTensorType>();
      auto lhsRank = lhs.getShape().size();
      auto rhsRank = rhs.getShape().size();
      if (lhsRank != rhsRank) {
        return;
      }
      SmallVector<int64_t, 2> dims;
      if (lhsRank == 1) {
        // dot product, result shape is <1>
        dims.push_back(1);
      } else {
        if (lhsRank != 2) {
          return;
        }
        dims.push_back(lhs.getShape()[0]);
        dims.push_back(rhs.getShape()[1]);
      }
      getResult()->setType(RankedTensorType::get(dims, lhs.getElementType()));
      return;
    }
  }];
}

def PrintOp : Toy_Op<"print"> {
  let summary = "print operation";
  let description = [{
    The "print" builtin operation prints a given input tensor, and produces
    no results.
  }];

  // The print operation takes an input tensor to print.
  let arguments = (ins F64Tensor:$input);
}

def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
  let summary = "tensor reshape operation";
  let description = [{
    Reshape operation is transforming its input tensor into a new tensor with
    the same number of elements but different shapes. For example:

    ```mlir
      %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
    ```
  }];

  let arguments = (ins F64Tensor:$input);
  let hasCanonicalizer = 1;

  // We expect that the reshape operation returns a statically shaped tensor.
  let results = (outs StaticShapeTensorOf<[F64]>);
}

def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
  let summary = "return operation";
  let description = [{
    The "return" operation represents a return operation within a function.
    The operation takes an optional tensor operand and produces no results.
    The operand type must match the signature of the function that contains
    the operation. For example:

    ```mlir
      func @foo() -> tensor<2xf64> {
        ...
        toy.return %0 : tensor<2xf64>
      }
    ```
  }];

  // The return operation takes an optional input operand to return. This
  // value must match the return type of the enclosing function.
  let arguments = (ins Variadic<F64Tensor>:$input);

  // Allow building a ReturnOp with no return operand.
  let builders = [OpBuilder<
    "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
  >];

  // Provide extra utility definitions on the c++ operation class definition.
  let extraClassDeclaration = [{
    bool hasOperand() { return getNumOperands() != 0; }
  }];

  // Invoke a static verify method to verify this return operation.
  let verifier = [{ return ::verify(*this); }];
}

def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
  let summary = "transpose operation";

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor);
  let hasCanonicalizer = 1;

  // Allow building a TransposeOp from the input operand.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *input", [{
      buildTransposeOp(b, result, input);
    }]
  >];
  let extraClassDeclaration = [{
    void inferShapes() {
      SmallVector<int64_t, 2> dims;
      auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
      dims.insert(dims.end(), arrayTy.getShape().begin(),
                  arrayTy.getShape().end());
      if (dims.size() == 2)
        std::swap(dims[0], dims[1]);
      getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
      return;
    }
  }];
}

#endif // TOY_OPS
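Once TableGen expands these definitions into `Ops.h.inc`/`Ops.cpp.inc`, the declared `builders` become overloads of each op's `build()` method, so IR can be created through `mlir::OpBuilder::create<>`. A hedged usage sketch (the function and value names are placeholders, not code from this commit; in this era of MLIR, SSA values are passed around as `mlir::Value *`):

```cpp
#include "mlir/IR/Builders.h"
#include "toy/Dialect.h"

// Create %t = toy.transpose(toy.add(%lhs, %rhs)) at `loc`, relying on the
// custom builders declared above. Single-result ops convert implicitly to
// Value*, as the MLIRGen code elsewhere in this commit also relies on.
mlir::Value *emitAddThenTranspose(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::Value *lhs, mlir::Value *rhs) {
  auto sum = builder.create<mlir::toy::AddOp>(loc, lhs, rhs);
  return builder.create<mlir::toy::TransposeOp>(loc, sum);
}
```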
@@ -26,10 +26,11 @@
namespace mlir {
class Pass;
-} // namespace mlir
-
namespace toy {
-std::unique_ptr<mlir::Pass> createShapeInferencePass();
-} // namespace toy
+std::unique_ptr<Pass> createShapeInferencePass();
+std::unique_ptr<Pass> createDeadFunctionEliminationPass();
+} // end namespace toy
+} // end namespace mlir

#endif // MLIR_TUTORIAL_TOY_PASSES_H
@@ -0,0 +1,38 @@
//===- ShapeInferenceInterface.td - Operation Interface for Shape Inference ----------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines the Shape Inference Op Interface.
//
//===----------------------------------------------------------------------===//

#ifdef SHAPE_INFERENCE_INTERFACE
#else
#define SHAPE_INFERENCE_INTERFACE

#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE

def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let methods = [
    InterfaceMethod<"Infer output shape for the current operation.",
                    "void", "inferShapes", (ins), [{}]>
  ];
}

#endif // SHAPE_INFERENCE_INTERFACE
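For orientation, `-gen-op-interface-decls` turns the definition above into an `OpInterface`-derived C++ class roughly shaped like the following (a simplified sketch, not the literal generated file; the real output also emits the concept/model plumbing in the `Traits` struct that dispatches to each op's `inferShapes()` implementation):

```cpp
// Rough shape of ShapeInferenceOpInterfaces.h.inc (simplified sketch).
class ShapeInference
    : public mlir::OpInterface<ShapeInference,
                               detail::ShapeInferenceOpInterfaceTraits> {
public:
  using OpInterface::OpInterface;

  /// Infer output shape for the current operation.
  void inferShapes();
};
```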
@@ -0,0 +1,61 @@
//===- DeadFunctionEliminationPass.cpp - Eliminate inlined functions ------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a Module level pass performing dead function
// elimination. This is required as a post-processing step after function
// inlining.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "toy/Passes.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>

namespace {
class DeadFunctionEliminationPass
    : public mlir::ModulePass<DeadFunctionEliminationPass> {
public:
  void runOnModule() override {
    std::string str = "main";
    auto module = getModule();
    for (auto &f : module) {
      // Eliminate dead functions that are not "main".
      if (str.find(f.getName().getStringRef()) == std::string::npos)
        f.erase();
    }
  }
};
} // namespace

/// Create a pass that eliminates inlined functions in toy.
std::unique_ptr<mlir::Pass> mlir::toy::createDeadFunctionEliminationPass() {
  return std::make_unique<DeadFunctionEliminationPass>();
}
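Both new passes are meant to be run from the driver through a pass manager; the exact pipeline lives in `toyc.cpp`, which is not part of this excerpt, so the wiring below is an assumed, illustrative sketch built only from the factory functions declared in `Passes.h` above:

```cpp
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "toy/Passes.h"

// Assumed driver-side wiring: specialize shapes first, then drop the
// now-unused generic functions.
mlir::LogicalResult runToyPasses(mlir::MLIRContext &context,
                                 mlir::ModuleOp module) {
  mlir::PassManager pm(&context);
  pm.addPass(mlir::toy::createShapeInferencePass());
  pm.addPass(mlir::toy::createDeadFunctionEliminationPass());
  return pm.run(module);
}
```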
@@ -0,0 +1,190 @@
//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/InliningUtils.h"

using namespace mlir;
using namespace mlir::toy;

//===----------------------------------------------------------------------===//
// ToyInlinerInterface
//===----------------------------------------------------------------------===//

/// This class defines the interface for handling inlining with Toy
/// operations.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All operations within toy can be inlined.
  bool isLegalToInline(Operation *, Region *,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator (toy.return) by replacing it with a new
  /// operation as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value *> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
  }
};

//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  addInterfaces<ToyInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//

/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
                            double value) {
  auto dataType = builder->getTensorType({}, builder->getF64Type());
  auto dataAttribute = DenseElementsAttr::get(dataType, value);
  ConstantOp::build(builder, state, dataType, dataAttribute);
}

/// Verifier for constant operation.
static mlir::LogicalResult verify(ConstantOp op) {
  // If the return type of the constant is not an unranked tensor, the shape
  // must match the shape of the attribute holding the data.
  auto resultType = op.getResult()->getType().dyn_cast<RankedTensorType>();
  if (!resultType)
    return success();

  auto attrType = op.value().getType().cast<mlir::TensorType>();
  if (attrType.getRank() != resultType.getRank()) {
    return op.emitOpError(
               "return type must match the one of the attached value "
               "attribute: ")
           << attrType.getRank() << " != " << resultType.getRank();
  }
  for (int dim = 0; dim < attrType.getRank(); ++dim) {
    if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
      return op.emitOpError(
                 "return type shape mismatches its attribute at dimension ")
             << dim << ": " << attrType.getShape()[dim]
             << " != " << resultType.getShape()[dim];
    }
  }
  return mlir::success();
}

static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::Value *lhs, mlir::Value *rhs) {
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands({lhs, rhs});
}

static void buildGenericCallOp(mlir::Builder *builder,
                               mlir::OperationState &state, StringRef callee,
                               ArrayRef<mlir::Value *> arguments) {
  // Generic call always returns an unranked Tensor initially.
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands(arguments);
  state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}

static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::Value *lhs, mlir::Value *rhs) {
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands({lhs, rhs});
}

static mlir::LogicalResult verify(ReturnOp op) {
  // We know that the parent operation is a function, because of the 'HasParent'
  // trait attached to the operation definition.
  auto function = cast<FuncOp>(op.getParentOp());

  /// ReturnOps can only have a single optional operand.
  if (op.getNumOperands() > 1)
    return op.emitOpError() << "expects at most 1 return operand";

  // The operand number and types must match the function signature.
  const auto &results = function.getType().getResults();
  if (op.getNumOperands() != results.size())
    return op.emitOpError()
           << "does not return the same number of values ("
           << op.getNumOperands() << ") as the enclosing function ("
           << results.size() << ")";

  // If the operation does not have an input, we are done.
  if (!op.hasOperand())
    return mlir::success();

  auto inputType = *op.operand_type_begin();
  auto resultType = results.front();

  // Check that the result type of the function matches the operand type.
  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
      resultType.isa<mlir::UnrankedTensorType>())
    return mlir::success();

  return op.emitError() << "type of return operand ("
                        << *op.operand_type_begin()
                        << ") doesn't match function result type ("
                        << results.front() << ")";
}

static void buildTransposeOp(mlir::Builder *builder,
                             mlir::OperationState &state, mlir::Value *value) {
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands(value);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"
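The builder helpers above lean on one distinction worth spelling out: `getTensorType(elementType)` produces an unranked tensor (`tensor<*xf64>`), used wherever shapes are not yet known, while `getTensorType(shape, elementType)` produces a ranked tensor. A small sketch using the same era-appropriate `mlir::Builder` API as the calls above:

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

void tensorTypeExamples(mlir::MLIRContext &context) {
  mlir::Builder b(&context);
  auto f64 = b.getF64Type();
  auto unranked = b.getTensorType(f64);       // tensor<*xf64>: shape unknown.
  auto zeroDim = b.getTensorType({}, f64);    // tensor<f64>: ranked, 0-D.
  auto ranked = b.getTensorType({2, 3}, f64); // tensor<2x3xf64>.
  (void)unranked; (void)zeroDim; (void)ranked;
}
```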
@ -25,30 +25,30 @@
|
|||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ScopedHashTable.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir::toy;
|
||||
using namespace toy;
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::cast;
|
||||
using llvm::dyn_cast;
|
||||
using llvm::isa;
|
||||
using llvm::makeArrayRef;
|
||||
using llvm::ScopedHashTableScope;
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
using std::make_unique;
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -57,56 +57,43 @@ namespace {
|
|||
/// This will emit operations that are specific to the Toy language, preserving
|
||||
/// the semantics of the language and (hopefully) allow to perform accurate
|
||||
/// analysis and transformation based on these high level semantics.
|
||||
///
|
||||
/// At this point we take advantage of the "raw" MLIR APIs to create operations
|
||||
/// that haven't been registered in any way with MLIR. These operations are
|
||||
/// unknown to MLIR, custom passes could operate by string-matching the name of
|
||||
/// these operations, but no other type checking or semantic is associated with
|
||||
/// them natively by MLIR.
|
||||
class MLIRGenImpl {
|
||||
public:
|
||||
MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
|
||||
MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
|
||||
|
||||
/// Public API: convert the AST for a Toy module (source file) to an MLIR
|
||||
/// Module.
|
||||
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
|
||||
/// Module operation.
|
||||
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
|
||||
// We create an empty MLIR module and codegen functions one at a time and
|
||||
// add them to the module.
|
||||
theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||
|
||||
for (FunctionAST &F : moduleAST) {
|
||||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->push_back(func);
|
||||
theModule.push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
// this won't do much, but it should at least check some structural
|
||||
// properties.
|
||||
if (failed(mlir::verify(*theModule))) {
|
||||
emitError(mlir::UnknownLoc::get(&context), "module verification error");
|
||||
// Verify the module after we have finished constructing it, this will check
|
||||
// the structural properties of the IR and invoke any specific verifiers we
|
||||
// have on the Toy operations.
|
||||
if (failed(mlir::verify(theModule))) {
|
||||
theModule.emitError("Module verification error");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return std::move(theModule);
|
||||
return theModule;
|
||||
}
|
||||
|
||||
private:
|
||||
/// In MLIR (like in LLVM) a "context" object holds the memory allocation and
|
||||
/// the ownership of many internal structure of the IR and provide a level
|
||||
/// of "uniquing" across multiple modules (types for instance).
|
||||
mlir::MLIRContext &context;
|
||||
/// A "module" matches a Toy source file: containing a list of functions.
|
||||
mlir::ModuleOp theModule;
|
||||
|
||||
/// A "module" matches a source file: it contains a list of functions.
|
||||
mlir::OwningModuleRef theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. It is
|
||||
/// re-initialized every time we enter a function and kept around as a
|
||||
/// convenience for emitting individual operations.
|
||||
/// The builder is stateful, in particular it keeps an "insertion point":
|
||||
/// this is where the next operations will be introduced.
|
||||
std::unique_ptr<mlir::OpBuilder> builder;
|
||||
/// The builder is a helper class to create IR inside a function. The builder
|
||||
/// is stateful, in particular it keeeps an "insertion point": this is where
|
||||
/// the next operations will be introduced.
|
||||
mlir::OpBuilder builder;
|
||||
|
||||
/// The symbol table maps a variable name to a value in the current scope.
|
||||
/// Entering a function creates a new scope, and the function arguments are
|
||||
|
@ -116,37 +103,35 @@ private:
|
|||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
|
||||
loc.line, loc.col, &context);
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
/// Declare a variable in the current scope, return true if the variable
|
||||
/// Declare a variable in the current scope, return success if the variable
|
||||
/// wasn't declared yet.
|
||||
bool declare(llvm::StringRef var, mlir::Value *value) {
|
||||
if (symbolTable.count(var)) {
|
||||
return false;
|
||||
}
|
||||
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
|
||||
if (symbolTable.count(var))
|
||||
return mlir::failure();
|
||||
symbolTable.insert(var, value);
|
||||
return true;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Create the prototype for an MLIR function with as many arguments as the
|
||||
/// provided Toy AST prototype.
|
||||
mlir::FuncOp mlirGen(PrototypeAST &proto) {
|
||||
auto location = loc(proto.loc());
|
||||
|
||||
// This is a generic function, the return type will be inferred later.
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
// Arguments type is uniformly a generic array.
|
||||
// Arguments type are uniformly unranked tensors.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
||||
getType(VarType{}));
|
||||
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
|
||||
auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
||||
auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
|
||||
|
||||
// Mark the function as generic: it'll require type specialization for every
|
||||
// call site.
|
||||
if (function.getNumArguments())
|
||||
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
|
||||
function.setAttr("toy.generic", builder.getUnitAttr());
|
||||
return function;
|
||||
}
|
||||
|
||||
|
@ -165,29 +150,39 @@ private:
|
|||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
llvm::zip(protoArgs, entryBlock.getArguments())) {
|
||||
declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
|
||||
if (failed(declare(std::get<0>(name_value)->getName(),
|
||||
std::get<1>(name_value))))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Create a builder for the function, it will be used throughout the codegen
|
||||
// to create operations in this function.
|
||||
builder = std::make_unique<mlir::OpBuilder>(function.getBody());
|
||||
// Set the insertion point in the builder to the beginning of the function
|
||||
// body, it will be used throughout the codegen to create operations in this
|
||||
// function.
|
||||
builder.setInsertionPointToStart(&entryBlock);
|
||||
|
||||
// Emit the body of the function.
|
||||
if (!mlirGen(*funcAST.getBody())) {
|
||||
if (mlir::failed(mlirGen(*funcAST.getBody()))) {
|
||||
function.erase();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Implicitly return void if no return statement was emited.
|
||||
// Implicitly return void if no return statement was emitted.
|
||||
// FIXME: we may fix the parser instead to always return the last expression
|
||||
// (this would possibly help the REPL case later)
|
||||
if (function.getBlocks().back().back().getName().getStringRef() !=
|
||||
"toy.return") {
|
||||
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
|
||||
mlirGen(fakeRet);
|
||||
ReturnOp returnOp;
|
||||
if (!entryBlock.empty())
|
||||
returnOp = dyn_cast<ReturnOp>(entryBlock.back());
|
||||
if (!returnOp) {
|
||||
builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
|
||||
} else if (returnOp.hasOperand()) {
|
||||
// Otherwise, if this return operation has an operand then add a result to
|
||||
// the function.
|
||||
function.setType(builder.getFunctionType(function.getType().getInputs(),
|
||||
getType(VarType{})));
|
||||
}
|
||||
|
||||
return function;
|
||||
|
@ -206,11 +201,11 @@ private:
|
|||
// and the result value is returned. If an error occurs we get a nullptr
|
||||
// and propagate.
|
||||
//
|
||||
mlir::Value *L = mlirGen(*binop.getLHS());
|
||||
if (!L)
|
||||
mlir::Value *lhs = mlirGen(*binop.getLHS());
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
mlir::Value *R = mlirGen(*binop.getRHS());
|
||||
if (!R)
|
||||
mlir::Value *rhs = mlirGen(*binop.getRHS());
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
auto location = loc(binop.loc());
|
||||
|
||||
|
@ -218,123 +213,112 @@ private:
|
|||
// support '+' and '*'.
|
||||
switch (binop.getOp()) {
|
||||
case '+':
|
||||
return builder->create<AddOp>(location, L, R).getResult();
|
||||
break;
|
||||
return builder.create<AddOp>(location, lhs, rhs);
|
||||
case '*':
|
||||
return builder->create<MulOp>(location, L, R).getResult();
|
||||
default:
|
||||
emitError(location, "error: invalid binary operator '")
|
||||
<< binop.getOp() << "'";
|
||||
return nullptr;
|
||||
return builder.create<MulOp>(location, lhs, rhs);
|
||||
}
|
||||
|
||||
emitError(location, "invalid binary operator '") << binop.getOp() << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// This is a reference to a variable in an expression. The variable is
|
||||
// expected to have been declared and so should have a value in the symbol
|
||||
// table, otherwise emit an error and return nullptr.
|
||||
/// This is a reference to a variable in an expression. The variable is
|
||||
/// expected to have been declared and so should have a value in the symbol
|
||||
/// table, otherwise emit an error and return nullptr.
|
||||
mlir::Value *mlirGen(VariableExprAST &expr) {
|
||||
if (symbolTable.count(expr.getName()))
|
||||
return symbolTable.lookup(expr.getName());
|
||||
emitError(loc(expr.loc()), "error: unknown variable '")
|
||||
if (auto *variable = symbolTable.lookup(expr.getName()))
|
||||
return variable;
|
||||
|
||||
emitError(loc(expr.loc()), "Error: unknown variable '")
|
||||
<< expr.getName() << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Emit a return operation, return true on success.
|
||||
bool mlirGen(ReturnExprAST &ret) {
|
||||
/// Emit a return operation. This will return failure if any generation fails.
|
||||
mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
|
||||
auto location = loc(ret.loc());
|
||||
// `return` takes an optional expression, we need to account for it here.
|
||||
if (!ret.getExpr().hasValue()) {
|
||||
builder->create<ReturnOp>(location);
|
||||
return true;
|
||||
|
||||
// 'return' takes an optional expression, handle that case here.
|
||||
mlir::Value *expr = nullptr;
|
||||
if (ret.getExpr().hasValue()) {
|
||||
if (!(expr = mlirGen(*ret.getExpr().getValue())))
|
||||
return mlir::failure();
|
||||
}
|
||||
auto *expr = mlirGen(*ret.getExpr().getValue());
|
||||
if (!expr)
|
||||
return false;
|
||||
builder->create<ReturnOp>(location, expr);
|
||||
return true;
|
||||
|
||||
// Otherwise, this return operation has zero operands.
|
||||
builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
|
||||
: ArrayRef<mlir::Value *>());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
  // Emit a literal/constant array. It will be emitted as a flattened array of
  // data in an Attribute attached to a `toy.constant` operation.
  // See documentation on [Attributes](LangRef.md#attributes) for more details.
  // Here is an excerpt:
  //
  //   Attributes are the mechanism for specifying constant data in MLIR in
  //   places where a variable is never allowed [...]. They consist of a name
  //   and a [concrete attribute value](#attribute-values). It is possible to
  //   attach attributes to operations, functions, and function arguments. The
  //   set of expected attributes, their structure, and their interpretation
  //   are all contextually dependent on what they are attached to.
  //
  // Example, the source level statement:
  //   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  // will be converted to:
  //   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
  //     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
  //      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
  //
  /// Emit a literal/constant array. It will be emitted as a flattened array of
  /// data in an Attribute attached to a `toy.constant` operation.
  /// See documentation on [Attributes](LangRef.md#attributes) for more details.
  /// Here is an excerpt:
  ///
  ///   Attributes are the mechanism for specifying constant data in MLIR in
  ///   places where a variable is never allowed [...]. They consist of a name
  ///   and a concrete attribute value. The set of expected attributes, their
  ///   structure, and their interpretation are all contextually dependent on
  ///   what they are attached to.
  ///
  /// Example, the source level statement:
  ///   var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  /// will be converted to:
  ///   %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
  ///     [[1.000000e+00, 2.000000e+00, 3.000000e+00],
  ///      [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
  ///
  mlir::Value *mlirGen(LiteralExprAST &lit) {
    auto location = loc(lit.loc());
    // The attribute is a vector with an attribute per element (number) in the
    // array, see `collectData()` below for more details.
    std::vector<mlir::Attribute> data;
    auto type = getType(lit.getDims());

    // The attribute is a vector with a floating point value per element
    // (number) in the array, see `collectData()` below for more details.
    std::vector<double> data;
    data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
                                 std::multiplies<int>()));
    collectData(lit, data);

    // FIXME: using a tensor type is a HACK here.
    // Can we do differently without registering a dialect? Using a string blob?
    mlir::Type elementType = mlir::FloatType::getF64(&context);
    auto dataType = builder->getTensorType(lit.getDims(), elementType);
    // The type of this attribute is tensor of 64-bit floating-point with the
    // shape of the literal.
    mlir::Type elementType = builder.getF64Type();
    auto dataType = builder.getTensorType(lit.getDims(), elementType);

    // This is the actual attribute that actually hold the list of values for
    // this array literal.
    auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
                             .cast<mlir::DenseElementsAttr>();
    // This is the actual attribute that holds the list of values for this
    // tensor literal.
    auto dataAttribute =
        mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));

    // Build the MLIR op `toy.constant`, only boilerplate below.
    return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
        .getResult();
    // Build the MLIR op `toy.constant`.
    return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
  }

  // Recursive helper function to accumulate the data that compose an array
  // literal. It flattens the nested structure in the supplied vector. For
  // example with this array:
  //  [[1, 2], [3, 4]]
  // we will generate:
  //  [ 1, 2, 3, 4 ]
  // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
  // Attributes are the way MLIR attaches constant to operations and functions.
  void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
  /// Recursive helper function to accumulate the data that compose an array
  /// literal. It flattens the nested structure in the supplied vector. For
  /// example with this array:
  ///  [[1, 2], [3, 4]]
  /// we will generate:
  ///  [ 1, 2, 3, 4 ]
  /// Individual numbers are represented as doubles.
  /// Attributes are the way MLIR attaches constant to operations.
  void collectData(ExprAST &expr, std::vector<double> &data) {
    if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
      for (auto &value : lit->getValues())
        collectData(*value, data);
      return;
    }

    assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
    mlir::Type elementType = mlir::FloatType::getF64(&context);
    auto attr = mlir::FloatAttr::getChecked(
        elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
    data.push_back(attr);
    data.push_back(cast<NumberExprAST>(expr).getValue());
  }

  // Emit a call expression. It emits specific operations for the `transpose`
  // builtin. Other identifiers are assumed to be user-defined functions.
  /// Emit a call expression. It emits specific operations for the `transpose`
  /// builtin. Other identifiers are assumed to be user-defined functions.
  mlir::Value *mlirGen(CallExprAST &call) {
    llvm::StringRef callee = call.getCallee();
    auto location = loc(call.loc());
    std::string callee = call.getCallee();
    if (callee == "transpose") {
      if (call.getArgs().size() != 1) {
        emitError(location, "MLIR codegen encountered an error: toy.transpose "
                            "does not accept multiple arguments");
        return nullptr;
      }
      mlir::Value *arg = mlirGen(*call.getArgs()[0]);
      return builder->create<TransposeOp>(location, arg).getResult();
    }

    // Codegen the operands first
    // Codegen the operands first.
    SmallVector<mlir::Value *, 4> operands;
    for (auto &expr : call.getArgs()) {
      auto *arg = mlirGen(*expr);

@@ -342,34 +326,41 @@ private:
        return nullptr;
      operands.push_back(arg);
    }
    // Calls to user-defined function are mapped to a custom call that takes
    // the callee name as an attribute.
    return builder->create<GenericCallOp>(location, call.getCallee(), operands)
        .getResult();

    // Builtin calls have their custom operation, meaning this is a
    // straightforward emission.
    if (callee == "transpose") {
      if (call.getArgs().size() != 1) {
        emitError(location, "MLIR codegen encountered an error: toy.transpose "
                            "does not accept multiple arguments");
        return nullptr;
      }
      return builder.create<TransposeOp>(location, operands[0]);
    }

    // Otherwise this is a call to a user-defined function. Calls to
    // user-defined functions are mapped to a custom call that takes the callee
    // name as an attribute.
    return builder.create<GenericCallOp>(location, callee, operands);
  }

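Likewise for calls: a user-level `multiply_transpose(a, b)` becomes a `toy.generic_call` carrying the callee name as an attribute, with an unranked result until inference runs. The snippet below is illustrative, not taken from the diff:

    %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"}
        : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
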
  // Emit a call expression. It emits specific operations for two builtins:
  // transpose(x) and print(x). Other identifiers are assumed to be user-defined
  // functions. Return false on failure.
  bool mlirGen(PrintExprAST &call) {
  /// Emit a print expression. It emits specific operations for two builtins:
  /// transpose(x) and print(x).
  mlir::LogicalResult mlirGen(PrintExprAST &call) {
    auto *arg = mlirGen(*call.getArg());
    if (!arg)
      return false;
    auto location = loc(call.loc());
    builder->create<PrintOp>(location, arg);
    return true;
      return mlir::failure();

    builder.create<PrintOp>(loc(call.loc()), arg);
    return mlir::success();
  }

  // Emit a constant for a single number (FIXME: semantic? broadcast?)
  /// Emit a constant for a single number (FIXME: semantic? broadcast?)
  mlir::Value *mlirGen(NumberExprAST &num) {
    auto location = loc(num.loc());
    mlir::Type elementType = mlir::FloatType::getF64(&context);
    auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
                                            loc(num.loc()));
    return builder->create<ConstantOp>(location, attr).getResult();
    return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
  }

  // Dispatch codegen for the right expression subclass using RTTI.
  /// Dispatch codegen for the right expression subclass using RTTI.
  mlir::Value *mlirGen(ExprAST &expr) {
    switch (expr.getKind()) {
    case toy::ExprAST::Expr_BinOp:

@@ -390,77 +381,75 @@ private:
    }
  }

  // Handle a variable declaration, we'll codegen the expression that forms the
  // initializer and record the value in the symbol table before returning it.
  // Future expressions will be able to reference this variable through symbol
  // table lookup.
  /// Handle a variable declaration, we'll codegen the expression that forms the
  /// initializer and record the value in the symbol table before returning it.
  /// Future expressions will be able to reference this variable through symbol
  /// table lookup.
  mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
    mlir::Value *value = nullptr;
    auto location = loc(vardecl.loc());
    if (auto init = vardecl.getInitVal()) {
      value = mlirGen(*init);
      if (!value)
        return nullptr;
      // We have the initializer value, but in case the variable was declared
      // with specific shape, we emit a "reshape" operation. It will get
      // optimized out later as needed.
      if (!vardecl.getType().shape.empty()) {
        value = builder
                    ->create<ReshapeOp>(
                        location, value,
                        getType(vardecl.getType()).cast<ToyArrayType>())
                    .getResult();
      }
    } else {
    auto init = vardecl.getInitVal();
    if (!init) {
      emitError(loc(vardecl.loc()),
                "missing initializer in variable declaration");
                "Missing initializer in variable declaration");
      return nullptr;
    }
    // Register the value in the symbol table
    declare(vardecl.getName(), value);

    mlir::Value *value = mlirGen(*init);
    if (!value)
      return nullptr;

    // We have the initializer value, but in case the variable was declared
    // with specific shape, we emit a "reshape" operation. It will get
    // optimized out later as needed.
    if (!vardecl.getType().shape.empty()) {
      value = builder.create<ReshapeOp>(loc(vardecl.loc()),
                                        getType(vardecl.getType()), value);
    }

    // Register the value in the symbol table.
    if (failed(declare(vardecl.getName(), value)))
      return nullptr;
    return value;
  }

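To make the reshape concrete: a declaration with an explicit shape, such as `var b<2, 1> = [1, 2];`, would produce IR along these lines (hypothetical names; attribute syntax approximated for this vintage of MLIR):

    %0 = "toy.constant"() {value = dense<[1.0, 2.0]> : tensor<2xf64>} : () -> tensor<2xf64>
    %1 = "toy.reshape"(%0) : (tensor<2xf64>) -> tensor<2x1xf64>
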
  /// Codegen a list of expressions, return false if one of them hit an error.
  bool mlirGen(ExprASTList &blockAST) {
    ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
  /// Codegen a list of expressions, return failure if one of them hits an error.
  mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
    ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
    for (auto &expr : blockAST) {
      // Specific handling for variable declarations, return statement, and
      // print. These can only appear in block list and not in nested
      // expressions.
      if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
        if (!mlirGen(*vardecl))
          return false;
          return mlir::failure();
        continue;
      }
      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
        if (!mlirGen(*ret))
          return false;
        return true;
      }
      if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
        return mlirGen(*ret);
      if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
        if (!mlirGen(*print))
          return false;
        if (mlir::failed(mlirGen(*print)))
          return mlir::success();
        continue;
      }

      // Generic expression dispatch codegen.
      if (!mlirGen(*expr))
        return false;
        return mlir::failure();
    }
    return true;
    return mlir::success();
  }

  /// Build a type from a list of shape dimensions. Types are `array` followed
  /// by an optional dimension list, example: array<2, 2>
  /// They are wrapped in a `toy` dialect (see next chapter) and get printed:
  ///   !toy.array<2, 2>
  template <typename T> mlir::Type getType(T shape) {
    SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
    return ToyArrayType::get(&context, shape64);
  /// Build a tensor type from a list of shape dimensions.
  mlir::Type getType(ArrayRef<int64_t> shape) {
    // If the shape is empty, then this type is unranked.
    if (shape.empty())
      return builder.getTensorType(builder.getF64Type());

    // Otherwise, we use the given shape.
    return builder.getTensorType(shape, builder.getF64Type());
  }

  /// Build an MLIR type from a Toy AST variable type
  /// (forward to the generic getType(T) above).
  /// Build an MLIR type from a Toy AST variable type (forward to the generic
  /// getType above).
  mlir::Type getType(const VarType &type) { return getType(type.shape); }
};
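The behavior of the new `getType` overload, illustrated (the comment is explanatory, not part of the commit):

    // getType({2, 3}) -> tensor<2x3xf64>  (ranked)
    // getType({})     -> tensor<*xf64>    (unranked; refined later by shape inference)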

@@ -1,4 +1,4 @@
//===- ShapeInferencePass.cpp - Toy Shape Inference / Func Specialization -===//
//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
//
// Copyright 2019 The MLIR Authors.
//

@@ -15,22 +15,14 @@
// limitations under the License.
// =============================================================================
//
// This file implements a Module level pass performing interprocedural
// This file implements a Function level pass performing interprocedural
// propagation of array shapes through function specialization.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"

#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"

@@ -39,48 +31,26 @@
#include "llvm/Support/raw_ostream.h"
#include <algorithm>

#define DEBUG_TYPE "toy-shape-inference"
#define DEBUG_TYPE "shape-inference"

using namespace toy;
using llvm::MutableArrayRef;
using llvm::raw_ostream;
using llvm::SmallVector;
using llvm::SmallVectorImpl;
using llvm::StringRef;
using llvm::Twine;

/// Create a mangled name for function specialization. We will simply append the
/// shape of the arguments to the function name. For example, calling
///
///   "toy.generic_call"(%1, %3) {callee: "foo"}
///       : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
///
/// would be mangled foo_2x3_2x3. This mangling isn't robust as the user could
/// have provided a function with a similar name, but we will claim this as a
/// feature: this allows the user to provide custom specializations!
static std::string mangle(StringRef funcName,
                          MutableArrayRef<mlir::OpOperand> operands) {
  std::string mangledName;
  mangledName.reserve(funcName.size() + operands.size() * 6);
  mangledName = funcName;
  for (auto &operand : operands) {
    auto arrayTy = operand.get()->getType().cast<ToyArrayType>();
    mangledName += "_";
    mlir::interleave(
        arrayTy.getShape(),
        [&](int64_t dim) { mangledName += Twine(dim).str(); },
        [&]() { mangledName += "x"; });
  }
  return mangledName;
}
using namespace mlir;

namespace {

/// The ShapeInferencePass is a ModulePass: it will run on the Module as a
/// whole. MLIR also supports FunctionPass which are restricted to modify a
/// single function at a time. This pass couldn't be a function pass due the
/// nature of its interprocedural transformations.
// clang-format off
#include "toy/ShapeInferenceOpInterfaces.h.inc"
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"

/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
/// shape inference.
///
/// The algorithm has two levels, first intra-procedurally:
///    Algorithm:
///
///   1) Build a worklist containing all the operations that are returning
///      a generic Toy array: these are the operations that need shape

@@ -94,132 +64,25 @@ namespace {
///   3) If the worklist is empty, the algorithm succeeded and we infer the
///      return type for the function from the return operation.
///
/// There is a twist though: when a call to a generic function is encountered,
/// shape inference requires the return type of the callee to be inferred first.
/// At this point we need to run specialize the callee by cloning it. Here is
/// the inter-procedural flow:
///
///   1) Keep a worklist of function to process. Start with function "main".
///   2) While the worklist isn't empty:
///      a) Take the last inserted function in the worklist.
///      b) Run the intra-procedural shape inference on this function.
///      c) If the intra-procedural shape inference can't complete, it returns
///         a Function that needs to be inferred first. In this case, queue this
///         new function and continue. Otherwise the inference succeeded and we
///         can pop from the queue.
///
class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {
class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
public:
  // One entry in the inter-procedural worklist. It keeps track of the
  // function to process, the mangled name for this specialization, and the
  // types of the arguments on which to specialize.
  struct FunctionToSpecialize {
    mlir::FuncOp function;
    std::string mangledName;
    SmallVector<mlir::Type, 4> argumentsType;
  };

  void runOnModule() override {
    auto module = getModule();
    auto main = module.lookupSymbol<mlir::FuncOp>("main");
    if (!main) {
      emitError(mlir::UnknownLoc::get(module.getContext()),
                "shape inference failed: can't find a main function\n");
      signalPassFailure();
      return;
    }

    /// Inter-procedural loop, initialize with `main` and iterate until we
    /// successfully infer the full reachable call-graph from main.
    SmallVector<FunctionToSpecialize, 8> worklist;
    worklist.push_back({main, "", {}});
    while (!worklist.empty()) {
      if (failed(specialize(worklist)))
        return;
    }

    // Delete any generic function left
    // FIXME: we may want this as a separate pass.
    for (mlir::FuncOp function :
         llvm::make_early_inc_range(module.getOps<mlir::FuncOp>())) {
      if (auto genericAttr =
              function.getAttrOfType<mlir::BoolAttr>("toy.generic")) {
        if (genericAttr.getValue())
          function.erase();
      }
  bool returnsGenericArray(Operation *op) {
    if (op->getNumResults() == 1) {
      if (!op->getResult(0)->getType().isa<ShapedType>())
        return true;
    }
    return false;
  }

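The `ShapeInference` class used below is generated from the new ShapeInferenceInterface.td (see the `-gen-op-interface-decls`/`-defs` tablegen rules added to this chapter's CMake). A minimal sketch of such an interface definition, following the tutorial's conventions rather than quoting the file verbatim:

    def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
      let methods = [
        InterfaceMethod<"Infer and set the output shape for the current operation.",
                        "void", "inferShapes">
      ];
    }

Any op that declares this interface in Ops.td can then be handled uniformly by the pass.
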
  /// Run inference on a function. If a mangledName is provided, we need to
  /// specialize the function: to this end clone it first.
  mlir::LogicalResult
  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
    FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
    mlir::FuncOp f = functionToSpecialize.function;

    // Check if cloning for specialization is needed (usually anything but main)
    // We will create a new function with the concrete types for the parameters
    // and clone the body into it.
    if (!functionToSpecialize.mangledName.empty()) {
      if (getModule().lookupSymbol<mlir::FuncOp>(
              functionToSpecialize.mangledName)) {
        funcWorklist.pop_back();
        // Function already specialized, move on.
        return mlir::success();
      }
      // Create a new function with a generic array return type, it will be
      // updated when the inference for the function body completes.
      auto type = mlir::FunctionType::get(functionToSpecialize.argumentsType,
                                          {ToyArrayType::get(&getContext())},
                                          &getContext());
      auto newFunction =
          mlir::FuncOp::create(f.getLoc(), functionToSpecialize.mangledName,
                               type, f.getDialectAttrs());
      getModule().push_back(newFunction);

      // Clone the function body
      mlir::BlockAndValueMapping mapper;
      f.cloneInto(newFunction, mapper);
      LLVM_DEBUG({
        llvm::dbgs() << "====== Cloned : \n";
        f.dump();
        llvm::dbgs() << "====== Into : \n";
        newFunction.dump();
      });
      f = newFunction;
      f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));
      // Remap the entry-block arguments
      // FIXME: this seems like a bug in `cloneInto()` above?
      auto &entryBlock = f.getBlocks().front();
      int blockArgSize = entryBlock.getArguments().size();
      assert(blockArgSize == static_cast<int>(f.getType().getInputs().size()));
      entryBlock.addArguments(f.getType().getInputs());
      auto argList = entryBlock.getArguments();
      for (int argNum = 0; argNum < blockArgSize; ++argNum) {
        argList[0]->replaceAllUsesWith(argList[blockArgSize]);
        entryBlock.eraseArgument(0);
      }
      assert(succeeded(mlir::verify(f)));
    }
    LLVM_DEBUG(llvm::dbgs()
               << "Run shape inference on : '" << f.getName() << "'\n");

    auto *toyDialect = getContext().getRegisteredDialect("toy");
    if (!toyDialect) {
      emitError(mlir::UnknownLoc::get(&getContext()),
                "Toy dialect is not registered");
      signalPassFailure();
      return mlir::failure();
    }
  void runOnFunction() override {
    auto f = getFunction();

    // Populate the worklist with the operations that need shape inference:
    // these are the Toy operations that return a generic array.
    // these are operations that return a generic array.
    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
    f.walk([&](mlir::Operation *op) {
      if (op->getDialect() == toyDialect) {
        if (op->getNumResults() == 1 &&
            op->getResult(0)->getType().cast<ToyArrayType>().isGeneric())
          opWorklist.insert(op);
      if (returnsGenericArray(op)) {
        opWorklist.insert(op);
      }
    });

@@ -228,154 +91,31 @@ public:
    while (!opWorklist.empty()) {
      // Find the next operation ready for inference, that is an operation
      // with all operands already resolved (non-generic).
      auto nextop = llvm::find_if(opWorklist, [](mlir::Operation *op) {
        return llvm::all_of(op->getOperandTypes(), [](mlir::Type ty) {
          return !ty.cast<ToyArrayType>().isGeneric();
        });
      auto nextop = llvm::find_if(opWorklist, [this](Operation *op) {
        return this->returnsGenericArray(op);
      });

      if (nextop == opWorklist.end())
        break; // failure: no operations can be inferred.

      mlir::Operation *op = *nextop;
      Operation *op = *nextop;
      opWorklist.erase(op);
      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");

      // The add operation is trivial: propagate the input type as is.
      if (auto addOp = llvm::dyn_cast<AddOp>(op)) {
        op->getResult(0)->setType(op->getOperand(0)->getType());
        continue;
      }

      // Transpose is easy: just invert the dimensions.
      if (auto transpose = llvm::dyn_cast<TransposeOp>(op)) {
        SmallVector<int64_t, 2> dims;
        auto arrayTy = transpose.getOperand()->getType().cast<ToyArrayType>();
        dims.insert(dims.end(), arrayTy.getShape().begin(),
                    arrayTy.getShape().end());
        transpose.getResult()->setType(ToyArrayType::get(&getContext(), dims));
        continue;
      }

      // Multiplication is a bit trickier, handle rank 1 as dot product and rank
      // 2 as matrix multiplications.
      // We need to be careful about rank mismatch here: the verifier could
      // catch it but shape inference earlier in the pass could generate an
      // invalid IR (from an invalid Toy input of course) and we wouldn't want
      // to crash here.
      if (auto mulOp = llvm::dyn_cast<MulOp>(op)) {
        auto lhs = mulOp.getLHS()->getType().cast<ToyArrayType>();
        auto rhs = mulOp.getRHS()->getType().cast<ToyArrayType>();
        auto lhsRank = lhs.getShape().size();
        auto rhsRank = rhs.getShape().size();
        if (lhsRank != rhsRank) {
          return mulOp.emitOpError(
              "shape mismatch: LHS and RHS must have the same "
              "rank for multiplication, got " +
              Twine(lhsRank) + " vs " + Twine(lhsRank));
        }
        SmallVector<int64_t, 2> dims;
        if (lhsRank == 1) {
          // dot product, result shape is <1>
          dims.push_back(1);
        } else if (lhsRank != 2) {
          return op->emitOpError(
              "shape mismatch: expect rank 1 or 2 for mul operands, got " +
              Twine(lhsRank));
        } else {
          dims.push_back(lhs.getShape()[0]);
          dims.push_back(rhs.getShape()[1]);
        }
        op->getResult(0)->setType(ToyArrayType::get(&getContext(), dims));
        continue;
      }

      // Process calls: lookup the callee after mangling the name with the
      // argument shapes. If the callee does not exist, we stop the inference
      // for this function, queue the callee in the inter-procedural work list,
      // and return. The current function stays in the work list and will
      // restart after the callee is processed.
      if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
        auto calleeName = callOp.getCalleeName();
        auto callee = getModule().lookupSymbol<mlir::FuncOp>(calleeName);
        if (!callee) {
          f.emitError("shape inference failed, call to unknown '")
              << calleeName << "'";
          signalPassFailure();
          return mlir::failure();
        }
        auto mangledName = mangle(calleeName, op->getOpOperands());
        LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
                                << "', mangled: '" << mangledName << "'\n");
        auto mangledCallee =
            getModule().lookupSymbol<mlir::FuncOp>(mangledName);
        if (!mangledCallee) {
          // Can't find the target, this is where we queue the request for the
          // callee and stop the inference for the current function now.
          funcWorklist.push_back({callee, std::move(mangledName),
                                  llvm::to_vector<4>(op->getOperandTypes())});
          return mlir::success();
        }
        // Found a specialized callee! Let's turn this into a normal call
        // operation.
        SmallVector<mlir::Value *, 8> operands(op->getOperands());
        mlir::OpBuilder builder(op);
        auto newCall =
            builder.create<mlir::CallOp>(op->getLoc(), mangledCallee, operands);
        if (newCall.getNumResults()) {
          op->getResult(0)->replaceAllUsesWith(newCall.getResult(0));
          op->erase();
          continue;
        }
      }
      auto shapeOp = dyn_cast<ShapeInference>(op);
      shapeOp.inferShapes();
    }

    // Done with inference on this function, removing it from the worklist.
    funcWorklist.pop_back();
    // Mark the function as non-generic now that inference has succeeded
    f.setAttr("toy.generic", mlir::BoolAttr::get(false, &getContext()));

    // If the operation worklist isn't empty, this indicates a failure.
    if (!opWorklist.empty()) {
      std::string str;
      llvm::raw_string_ostream errorMsg(str);
      errorMsg << "shape inference failed, " << opWorklist.size()
               << " operations couldn't be inferred\n";
      for (auto *ope : opWorklist)
        errorMsg << " - " << *ope << "\n";
      f.emitError(errorMsg.str());
      signalPassFailure();
      return mlir::failure();
      auto diag = f.emitError("Shape inference failed, ")
                  << opWorklist.size() << " operations couldn't be inferred\n";
    }

    // Finally, update the return type of the function based on the argument to
    // the return operation.
    for (auto &block : f.getBlocks()) {
      auto ret = llvm::cast<ReturnOp>(block.getTerminator());
      if (!ret)
        continue;
      if (ret.getNumOperands() &&
          f.getType().getResult(0) == ret.getOperand()->getType())
        // type match, we're done
        break;
      SmallVector<mlir::Type, 1> retTy;
      if (ret.getNumOperands())
        retTy.push_back(ret.getOperand()->getType());
      std::vector<mlir::Type> argumentsType;
      for (auto arg : f.getArguments())
        argumentsType.push_back(arg->getType());
      auto newType =
          mlir::FunctionType::get(argumentsType, retTy, &getContext());
      f.setType(newType);
      assert(succeeded(mlir::verify(f)));
      break;
    }
    return mlir::success();
  }
};
} // end anonymous namespace

namespace toy {
std::unique_ptr<mlir::Pass> createShapeInferencePass() {
/// Create a Shape Inference pass.
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}
} // namespace toy

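With the interface in place, each operation supplies its own inference logic instead of the pass special-casing Add/Mul/Transpose as the removed code did. A hedged sketch of what a transpose's hook could look like in Dialect.cpp (exact accessor spellings may differ at this revision):

    /// Infer the output shape of a transpose: the reversed input shape.
    void TransposeOp::inferShapes() {
      auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
      SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
      getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
    }
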
@@ -15,24 +15,25 @@
// limitations under the License.
// =============================================================================
//
// This file implements a simple combiner for optimizing pattern in the Toy
// dialect.
// This file implements a set of simple combiners for optimizing operations in
// the Toy dialect.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

#include "toy/Dialect.h"
#include <numeric>

namespace toy {
using namespace mlir;
using namespace toy;

namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "ToyCombine.inc"
} // end anonymous namespace

/// Fold transpose(transpose(x)) -> transpose(x)
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// We register this pattern to match every toy.transpose in the IR.
  /// The "benefit" is used by the framework to order the patterns and process

@@ -40,9 +41,9 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// This method is attempting to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. It is expected
  /// to interact with it to perform any changes to the IR from here.
  /// This method attempts to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. The pattern is
  /// expected to interact with it to perform any changes to the IR from here.
  mlir::PatternMatchResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {

@@ -50,106 +51,28 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
    mlir::Value *transposeInput = op.getOperand();
    TransposeOp transposeInputOp =
        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());

    // If the input is defined by another Transpose, bingo!
    if (!transposeInputOp)
      return matchFailure();

    // Use the rewriter to perform the replacement
    // Use the rewriter to perform the replacement.
    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
    return matchSuccess();
  }
};

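Concretely, given IR like the following (names and shapes illustrative):

    %1 = "toy.transpose"(%0) : (tensor<2x3xf64>) -> tensor<3x2xf64>
    %2 = "toy.transpose"(%1) : (tensor<3x2xf64>) -> tensor<2x3xf64>

the pattern rewires every use of %2 directly to %0; the inner transpose, now dead, is cleaned up by the rewriter.
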
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;

  mlir::PatternMatchResult
  matchAndRewrite(ReshapeOp reshape,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current reshape.
    ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
        reshape.getOperand()->getDefiningOp());
    // If the input is defined by another constant, bingo!
    if (!constantOp)
      return matchFailure();

    auto reshapeType = reshape.getType().cast<ToyArrayType>();
    if (auto valueAttr =
            constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
      // FIXME Check matching of element count!
      //      auto oldType = constantOp.getType();
      auto newType = rewriter.getTensorType(
          reshapeType.getShape(), valueAttr.getType().getElementType());
      auto newAttr = valueAttr.reshape(newType);
      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
                                              newAttr);
    } else if (auto valueAttr =
                   constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
      // Broadcast
      auto dataSize = std::accumulate(reshapeType.getShape().begin(),
                                      reshapeType.getShape().end(), 1,
                                      std::multiplies<int>());
      std::vector<mlir::Attribute> data(dataSize, valueAttr);
      auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
                                             reshapeType.getElementType());
      auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
      rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
                                              newAttr);
    } else {
      llvm_unreachable("Unsupported Constant format");
    }
    return matchSuccess();
  }
};

/// Fold reshape(reshape(x)) -> reshape(x)
struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;

  mlir::PatternMatchResult
  matchAndRewrite(ReshapeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current reshape.
    mlir::Value *reshapeInput = op.getOperand();

    // If the input is defined by another reshape, bingo!
    if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
      return matchFailure();

    // Use the rewriter to perform the replacement
    rewriter.replaceOp(op, {reshapeInput});
    return matchSuccess();
  }
};

/// Fold reshape(x) -> x, when input type matches output type
struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
  using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;

  mlir::PatternMatchResult
  matchAndRewrite(ReshapeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    if (op.getOperand()->getType() != op.getType())
      return matchFailure();
    rewriter.replaceOp(op, {op.getOperand()});
    return matchSuccess();
  }
};

} // end anonymous namespace.

// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<SimplifyRedundantTranspose>(context);
}

// Register our patterns for rewrite by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(
    mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
  results.insert<SimplifyReshapeConstant, SimplifyReshapeReshape,
                 SimplifyNullReshape>(context);
/// Register our patterns as "canonicalization" patterns on the ReshapeOp so
/// that they can be picked up by the Canonicalization framework.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
                 FoldConstantReshapeOptPattern>(context);
}

} // namespace toy

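These registrations only take effect when a canonicalization run is scheduled; in this chapter that is done by the driver (see the toyc.cpp hunk at the end of this diff):

    mlir::PassManager pm(&context);
    pm.addPass(mlir::createCanonicalizerPass()); // applies the registered patterns
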
@@ -0,0 +1,73 @@
//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines language-specific pattern match optimizations for Toy using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//

#ifndef TOY_COMBINE
#define TOY_COMBINE

#ifndef OP_BASE
include "toy/Ops.td"
#endif // OP_BASE

/// Note: The DRR definition used for defining patterns is shown below:
///
/// class Pattern<
///    dag sourcePattern, list<dag> resultPatterns,
///    list<dag> additionalConstraints = [],
///    dag benefitsAdded = (addBenefit 0)
/// >;

//===----------------------------------------------------------------------===//
// Basic Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//

// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;

//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite using Native Code Call
//===----------------------------------------------------------------------===//

// Native Code Calls may be used for more complex transformations using inline
// C++ and C++ helper functions.

// Reshape(Constant(x)) = x'
def ReshapeConstant :
  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
def FoldConstantReshapeOptPattern : Pat<
  (ReshapeOp:$res (ConstantOp $arg)),
  (ConstantOp (ReshapeConstant $arg, $res))>;

//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite with Constraints
//===----------------------------------------------------------------------===//

// DRR allows for constraint checking when the transformation is conditional
// on operand properties.

// Reshape(x) = x, where input and output shapes are identical
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
  (ReshapeOp:$res $arg), (replaceWithValue $arg),
  [(TypesAreIdentical $res, $arg)]>;

#endif // TOY_COMBINE

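`mlir-tblgen -gen-rewriters` (wired up in this chapter's CMakeLists) expands each `Pat` record above into an `OpRewritePattern` subclass in ToyCombine.inc. As intuition only — this is a hand-written equivalent, not the verbatim generated code — `ReshapeReshapeOptPattern` behaves much like the `SimplifyReshapeReshape` it replaces:

    struct ReshapeReshapeOptPattern : public mlir::OpRewritePattern<ReshapeOp> {
      using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
      mlir::PatternMatchResult
      matchAndRewrite(ReshapeOp op,
                      mlir::PatternRewriter &rewriter) const override {
        // Match reshape(reshape(x)) and rebuild a single reshape of x.
        auto inner = llvm::dyn_cast_or_null<ReshapeOp>(
            op.getOperand()->getDefiningOp());
        if (!inner)
          return matchFailure();
        rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(),
                                               inner.getOperand());
        return matchSuccess();
      }
    };
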
@@ -1,387 +0,0 @@
//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//

#include "toy/Dialect.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"

using llvm::ArrayRef;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;

namespace toy {
namespace detail {

/// This class holds the implementation of the ToyArrayType.
/// It is intended to be uniqued based on its content and owned by the context.
struct ToyArrayTypeStorage : public mlir::TypeStorage {
  /// This defines how we unique this type in the context: our key contains
  /// only the shape, a more complex type would have multiple entries in the
  /// tuple here.
  /// The element of the tuples usually matches 1-1 the arguments from the
  /// public `get()` method arguments from the facade.
  using KeyTy = std::tuple<ArrayRef<int64_t>>;
  static unsigned hashKey(const KeyTy &key) {
    return llvm::hash_combine(std::get<0>(key));
  }
  /// When the key hash hits an existing type, we compare the shape themselves
  /// to confirm we have the right type.
  bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }

  /// This is a factory method to create our type storage. It is only
  /// invoked after looking up the type in the context using the key and not
  /// finding it.
  static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                        const KeyTy &key) {
    // Copy the shape array into the bumpptr allocator owned by the context.
    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));

    // Allocate the instance for the ToyArrayTypeStorage itself
    auto *storage = allocator.allocate<ToyArrayTypeStorage>();
    // Initialize the instance using placement new.
    return new (storage) ToyArrayTypeStorage(shape);
  }

  ArrayRef<int64_t> getShape() const { return shape; }

private:
  ArrayRef<int64_t> shape;

  /// Constructor is only invoked from the `construct()` method above.
  ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
};

} // namespace detail

mlir::Type ToyArrayType::getElementType() {
  return mlir::FloatType::getF64(getContext());
}

ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
                               ArrayRef<int64_t> shape) {
  return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
}

ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
                MulOp, AddOp, ReturnOp>();
  addTypes<ToyArrayType>();
}

/// Parse a type registered to this dialect, we expect only Toy arrays.
mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
  // Sanity check: we only support array or array<...>
  if (!tyData.startswith("array")) {
    emitError(loc, "invalid Toy type '" + tyData + "', array expected");
    return nullptr;
  }
  // Drop the "array" prefix from the type name, we expect either an empty
  // string or just the shape.
  tyData = tyData.drop_front(StringRef("array").size());
  // This is the generic array case without shape, early return it.
  if (tyData.empty())
    return ToyArrayType::get(getContext());

  // Use a regex to parse the shape (for efficient we should store this regex in
  // the dialect itself).
  SmallVector<StringRef, 4> matches;
  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
  if (!shapeRegex.match(tyData, &matches)) {
    emitError(loc, "invalid toy array shape '" + tyData + "'");
    return nullptr;
  }
  SmallVector<int64_t, 4> shape;
  // Iterate through the captures, skip the first one which is the full string.
  for (auto dimStr :
       llvm::make_range(std::next(matches.begin()), matches.end())) {
    if (dimStr.startswith(","))
      continue; // POSIX misses non-capturing groups.
    if (dimStr.empty())
      continue; // '*' makes it an optional group capture
    // Convert the capture to an integer
    unsigned long long dim;
    if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
      emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr);
      return mlir::Type();
    }
    shape.push_back(dim);
  }
  // Finally we collected all the dimensions in the shape,
  // create the array type.
  return ToyArrayType::get(getContext(), shape);
}

/// Print a Toy array type, for example `array<2, 3, 4>`
void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
  auto arrayTy = type.dyn_cast<ToyArrayType>();
  if (!arrayTy) {
    os << "unknown toy type";
    return;
  }
  os << "array";
  if (!arrayTy.getShape().empty()) {
    os << "<";
    mlir::interleaveComma(arrayTy.getShape(), os);
    os << ">";
  }
}

////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////

/// Helper to verify that the result of an operation is a Toy array type.
template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
  if (!op->getResult()->getType().template isa<ToyArrayType>()) {
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its argument, got "
       << op->getResult()->getType();
    return op->emitOpError(os.str());
  }
  return mlir::success();
}

/// Helper to verify that the two operands of a binary operation are Toy
/// arrays..
template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
  if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its LHS, got "
       << op->getOperand(0)->getType();
    return op->emitOpError(os.str());
  }
  if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its LHS, got "
       << op->getOperand(0)->getType();
    return op->emitOpError(os.str());
  }
  return mlir::success();
}

/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
                       ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
  state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
  auto dataAttribute = builder->getNamedAttr("value", value);
  state.attributes.push_back(dataAttribute);
}

/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::FloatAttr value) {
  // Broadcast and forward to the other build factory
  mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
  auto dataType = builder->getTensorType({1}, elementType);
  auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
                           .cast<mlir::DenseElementsAttr>();

  ConstantOp::build(builder, state, {1}, dataAttribute);
}

/// Verifier for constant operation.
mlir::LogicalResult ConstantOp::verify() {
  // Ensure that the return type is a Toy array
  if (failed(verifyToyReturnArray(this)))
    return mlir::failure();

  // We expect the constant itself to be stored as an attribute.
  auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
  if (!dataAttr) {
    return emitOpError(
        "missing valid `value` DenseElementsAttribute on toy.constant()");
  }
  auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
  if (!attrType) {
    return emitOpError(
        "missing valid `value` DenseElementsAttribute on toy.constant()");
  }

  // If the return type of the constant is not a generic array, the shape must
  // match the shape of the attribute holding the data.
  auto resultType = getResult()->getType().cast<ToyArrayType>();
  if (!resultType.isGeneric()) {
    if (attrType.getRank() != resultType.getRank()) {
      return emitOpError("The rank of the toy.constant return type must match "
                         "the one of the attached value attribute: " +
                         Twine(attrType.getRank()) +
                         " != " + Twine(resultType.getRank()));
    }
    for (int dim = 0; dim < attrType.getRank(); ++dim) {
      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
        std::string msg;
        raw_string_ostream os(msg);
        return emitOpError(
            "Shape mismatch between toy.constant return type and its "
            "attribute at dimension " +
            Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
            " != " + Twine(resultType.getShape()[dim]));
      }
    }
  }
  return mlir::success();
}

void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
                          StringRef callee, ArrayRef<mlir::Value *> arguments) {
  // Generic call always returns a generic ToyArray initially
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.assign(arguments.begin(), arguments.end());
  auto calleeAttr = builder->getStringAttr(callee);
  state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
}

mlir::LogicalResult GenericCallOp::verify() {
  // Verify that every operand is a Toy Array
  for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
    if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
      std::string msg;
      raw_string_ostream os(msg);
      os << "expects a Toy Array for its " << opId << " operand, got "
         << getOperand(opId)->getType();
      return emitOpError(os.str());
    }
  }
  return mlir::success();
}

/// Return the name of the callee.
StringRef GenericCallOp::getCalleeName() {
  return getAttr("callee").cast<mlir::StringAttr>().getValue();
}

template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
  if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its argument, got "
       << op->getOperand()->getType();
    return op->emitOpError(os.str());
  }
  return mlir::success();
}

void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
                     mlir::Value *value) {
  // Return does not return any value and has an optional single argument
  if (value)
    state.operands.push_back(value);
}

mlir::LogicalResult ReturnOp::verify() {
  if (getNumOperands() > 1)
    return emitOpError("expects zero or one operand, got " +
                       Twine(getNumOperands()));
  if (hasOperand() && failed(verifyToySingleOperand(this)))
    return mlir::failure();
  return mlir::success();
}

void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::Value *value) {
  // Print does not return any value and has a single argument
  state.operands.push_back(value);
}

mlir::LogicalResult PrintOp::verify() {
  if (failed(verifyToySingleOperand(this)))
    return mlir::failure();
  return mlir::success();
}

void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
                        mlir::Value *value) {
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.push_back(value);
}

mlir::LogicalResult TransposeOp::verify() {
  if (failed(verifyToySingleOperand(this)))
    return mlir::failure();
  return mlir::success();
}

void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
                      mlir::Value *value, ToyArrayType reshapedType) {
  state.types.push_back(reshapedType);
  state.operands.push_back(value);
}

mlir::LogicalResult ReshapeOp::verify() {
  if (failed(verifyToySingleOperand(this)))
    return mlir::failure();
  auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
  if (!retTy)
    return emitOpError("toy.reshape is expected to produce a Toy array");
  if (retTy.isGeneric())
    return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
                       "got a generic one.");
  return mlir::success();
}

void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.push_back(lhs);
  state.operands.push_back(rhs);
}

mlir::LogicalResult AddOp::verify() {
  if (failed(verifyToyBinOperands(this)))
    return mlir::failure();
  return mlir::success();
}

void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  state.types.push_back(ToyArrayType::get(builder->getContext()));
  state.operands.push_back(lhs);
  state.operands.push_back(rhs);
}

mlir::LogicalResult MulOp::verify() {
  if (failed(verifyToyBinOperands(this)))
    return mlir::failure();
  return mlir::success();
}

} // namespace toy

@ -80,54 +80,63 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
  return parser.ParseModule();
}

mlir::LogicalResult optimize(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(createShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());
  // Apply any generic pass manager command line options.
  applyPassManagerCLOptions(pm);

  return pm.run(module);
}

int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
  // Handle '.toy' input to the compiler.
  if (inputType != InputType::MLIR &&
      !llvm::StringRef(inputFilename).endswith(".mlir")) {
    auto moduleAST = parseInputFile(inputFilename);
    module = mlirGen(context, *moduleAST);
    return !module ? 1 : 0;
  }

  // Otherwise, the input is '.mlir'.
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
      llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
  if (std::error_code EC = fileOrErr.getError()) {
    llvm::errs() << "Could not open input file: " << EC.message() << "\n";
    return -1;
  }

  // Parse the input mlir.
  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
  module = mlir::parseSourceFile(sourceMgr, &context);
  if (!module) {
    llvm::errs() << "Error can't load file " << inputFilename << "\n";
    return 3;
  }
  return 0;
}

int dumpMLIR() {
  // Register our Dialect with MLIR
  mlir::registerDialect<ToyDialect>();
  // Register our Dialect with MLIR.
  mlir::registerDialect<mlir::toy::ToyDialect>();

  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  if (inputType == InputType::MLIR ||
      llvm::StringRef(inputFilename).endswith(".mlir")) {
    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
        llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
    if (std::error_code EC = fileOrErr.getError()) {
      llvm::errs() << "Could not open input file: " << EC.message() << "\n";
      return -1;
    }
    llvm::SourceMgr sourceMgr;
    sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
    module = mlir::parseSourceFile(sourceMgr, &context);
    if (!module) {
      llvm::errs() << "Error can't load file " << inputFilename << "\n";
      return 3;
    }
    if (failed(mlir::verify(*module))) {
      llvm::errs() << "Error verifying MLIR module\n";
      return 4;
    }
  } else {
    auto moduleAST = parseInputFile(inputFilename);
    module = mlirGen(context, *moduleAST);
  }
  if (!module)
    return 1;
  if (int error = loadMLIR(context, module))
    return error;

  if (EnableOpt) {
    if (failed(optimize(*module))) {
      llvm::errs() << "Module optimization failed\n";
      return 7;
    }
    mlir::PassManager pm(&context);
    // Apply any generic pass manager command line options and run the pipeline.
    applyPassManagerCLOptions(pm);

    // Add a run of the canonicalizer to optimize the mlir module.
    pm.addPass(mlir::createCanonicalizerPass());

    // Inline all functions into main and then delete them.
    pm.addPass(mlir::createInlinerPass());
    pm.addPass(mlir::toy::createDeadFunctionEliminationPass());

    // Now that there is only one function, we can infer the shapes of each of
    // the operations.
    pm.addPass(mlir::toy::createShapeInferencePass());

    if (mlir::failed(pm.run(*module)))
      return 4;
  }

  module->dump();
  return 0;
}
@ -1,242 +1,118 @@
# Chapter 4: High-level Language-Specific Analysis and Transformation
# Chapter 4: Using Interfaces

Creating a dialect that closely represents the semantics of an input language
enables analyses and transformations in MLIR that are generally performed on
the language AST. For example, `clang` has a fairly
[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html)
for performing template instantiation in C++.

[Interfaces](../../Interfaces.md) provide a generic method for applying
transformations across dialects. We first describe how to leverage an existing
MLIR interface, and then walk through writing your own interface.

Another aspect is optimization. While some previous language-specific
optimizations have been implemented in LLVM (like the
[ARC optimizer](http://llvm.org/doxygen/ObjCARCOpts_8cpp_source.html#l00468)),
this has come at the cost of either adding enough concepts to LLVM to embed
the high-level semantics of the input, or of using fragile "best-effort"
metadata to decorate the IR with the information needed for these custom
optimizations.

We show in this chapter how to leverage the Toy dialect and its high-level
semantics to perform transformations that would be difficult in LLVM: first a
simple combine of two redundant operations, and second a full interprocedural
shape inference with function specialization.

## Function Inlining

In order to apply function inlining in the Toy dialect, we override the
`DialectInlinerInterface` in Toy, enable inlining, and add special handling for
the return operation:

```Toy(.cpp)
//===----------------------------------------------------------------------===//
// ToyInlinerInterface
//===----------------------------------------------------------------------===//

/// This class defines the interface for handling inlining with Toy
/// operations.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All operations within toy can be inlined.
  bool isLegalToInline(Operation *, Region *,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle the given inlined terminator (toy.return) by replacing it with a
  /// new operation as necessary.
  void handleTerminator(Operation *op,
                        ArrayRef<Value *> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
  }
};
```

# Basic Optimization: Eliminate Redundant Transpose

Let's start with a simple pattern and try to eliminate a sequence of two
transposes that cancel out: `transpose(transpose(X)) -> X`. Here is the
corresponding Toy example:

```Toy(.toy)
def transpose_transpose(x) {
  return transpose(transpose(x));
}
```

Which corresponds to the following IR:

```MLIR(.mlir)
func @transpose_transpose(%arg0: !toy<"array">)
  attributes {toy.generic: true} {
  %0 = "toy.transpose"(%arg0) : (!toy<"array">) -> !toy<"array">
  %1 = "toy.transpose"(%0) : (!toy<"array">) -> !toy<"array">
  "toy.return"(%1) : (!toy<"array">) -> ()
}
```

This is a good example of a transformation that is trivial to match on the Toy
IR but that would be quite hard for LLVM to figure out. For example, today
clang can't optimize away the temporary array and the computation with the
naive transpose expressed with these loops:

```c++
#define N 100
#define M 100

void sink(void *);
void double_transpose(int A[N][M]) {
  int B[M][N];
  for(int i = 0; i < N; ++i) {
    for(int j = 0; j < M; ++j) {
      B[j][i] = A[i][j];
    }
  }
  for(int i = 0; i < N; ++i) {
    for(int j = 0; j < M; ++j) {
      A[i][j] = B[j][i];
    }
  }
  sink(A);
}
```

For a simple rewrite involving matching a tree-like pattern in the IR and
replacing it with a different set of operations, we can plug into the MLIR
`Canonicalizer` pass by implementing a `RewritePattern`:

```c++
/// Fold transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::RewritePattern {
  /// We register this pattern to match every toy.transpose in the IR.
  /// The "benefit" is used by the framework to order the patterns and process
  /// them in order of profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : RewritePattern(TransposeOp::getOperationName(), /*benefit=*/1,
                       context) {}

  /// This method attempts to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. It is expected
  /// to interact with it to perform any changes to the IR from here.
  mlir::PatternMatchResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // We can directly cast the current operation as this will only get invoked
    // on TransposeOp.
    TransposeOp transpose = op->cast<TransposeOp>();
    // Look through the input of the current transpose.
    mlir::Value *transposeInput = transpose.getOperand();
    // If the input is defined by another transpose, bingo!
    if (!matchPattern(transposeInput, mlir::m_Op<TransposeOp>()))
      return matchFailure();

    auto transposeInputOp =
        transposeInput->getDefiningOp()->cast<TransposeOp>();
    // Use the rewriter to perform the replacement.
    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
    return matchSuccess();
  }
};
```

Let's see how to improve our `TransposeOp` by extending it with a new static
method:

```c++
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
static void getCanonicalizationPatterns(mlir::OwningRewritePatternList &results,
                                        mlir::MLIRContext *context) {
  results.push_back(std::make_unique<SimplifyRedundantTranspose>(context));
}
```

Next, we call into the interface by adding an inliner pass to the pass manager
for toy:

```Toy(.cpp)
pm.addPass(mlir::createInlinerPass());
```

With inlining enabled, each `toy.generic_call` is replaced by the body of its
callee. As an illustrative sketch (the exact output depends on the rest of the
pipeline), inlining the `multiply_transpose` calls from this chapter's test
might transform the IR as follows:
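```MLIR(.mlir)
// Before inlining: main calls the generic function (types as in the
// chapter's codegen test).
%4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose}
       : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>

// After inlining (sketch): the callee body is spliced into main, and
// handleTerminator forwards the operand of the inlined toy.return.
%4 = "toy.transpose"(%3) : (tensor<2x3xf64>) -> tensor<*xf64>
%5 = "toy.mul"(%1, %4) : (tensor<2x3xf64>, tensor<*xf64>) -> tensor<*xf64>
```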

## Shape Inference

The Toy language allows for implicit shapes and hence requires shape inference.
We implement shape inference as a generic
[Operation Interface](../../Interfaces.md#operation-interfaces).

1.  We first create the `ShapeInferenceOpInterface` by specializing the
    `OpInterface` class using [ODS](../../OpDefinitions.md#operation-interfaces).
    This class defines the interface methods that Toy operations must override
    for shape inference.

```Toy(.cpp)
def ShapeInferenceOpInterface : OpInterface<"ShapeInferenceOpInterface"> {
  let methods = [
    InterfaceMethod<
      "bool", "returnsGenericArray", (ins), [{
        if (op.getNumResults() == 1) {
          auto arrayTy = op.getResult()->getType().cast<RankedTensorType>();
          return arrayTy.getShape().empty();
        }
        return false;
      }]>,
    InterfaceMethod<"void", "inferShapes", (ins), [{}]>
  ];
}
```
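
TableGen generates a C++ interface class from this definition that behaves
like any other operation interface; a minimal sketch of querying it from a
pass, assuming the generated declarations are included:

```c++
// Sketch: query the ODS-generated interface on an arbitrary operation.
// ShapeInferenceOpInterface is the class generated from the definition above.
if (auto shapeOp = llvm::dyn_cast<ShapeInferenceOpInterface>(op)) {
  if (shapeOp.returnsGenericArray())
    shapeOp.inferShapes(); // Overridden per operation, as shown next.
}
```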

2.  Next, we override the `inferShapes()` method within Toy operations. As an
    example, for the transpose op, the result shape is inferred by swapping the
    dimensions of the input tensor.

```Toy(.cpp)
void inferShapes() {
  SmallVector<int64_t, 2> dims;
  auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
  dims.insert(dims.end(), arrayTy.getShape().begin(),
              arrayTy.getShape().end());
  if (dims.size() == 2)
    std::swap(dims[0], dims[1]);
  getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
```

The implementation of this rewriter is in `ToyCombine.cpp`. We also need to
update our main file, `toyc.cpp`, to add an optimization pipeline. In MLIR, the
optimizations are run through a `PassManager`, in a similar way to LLVM:

```c++
mlir::PassManager pm(ctx);
pm.addPass(mlir::createCanonicalizerPass());
pm.run(&module);
```

3.  We then create a generic shape inference function pass that uses operation
    casting to access the `inferShapes()` method. This intraprocedural pass
    executes after function inlining and iterates over the operations in a
    worklist, calling `inferShapes()` on each operation with unknown result
    shapes; a sketch of this loop follows the next step.

4.  Finally, we call into the shape inference pass by adding it to the pass
    manager for toy:

```Toy(.cpp)
pm.addPass(mlir::createShapeInferencePass());
```
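
A minimal sketch of the worklist loop from step 3, assuming the
`ShapeInferenceOpInterface` above (helper and variable names here are
illustrative; the real pass lives in `mlir/ShapeInferencePass.cpp`):

```c++
// Intra-procedural inference sketch: repeatedly pick an operation whose
// operands all have inferred (ranked) types and infer its result shape.
void inferShapesInFunction(mlir::FuncOp f) {
  // Collect every operation that still returns a generic (unranked) array.
  llvm::SmallPtrSet<mlir::Operation *, 16> worklist;
  f.walk([&](mlir::Operation *op) {
    if (auto shapeOp = llvm::dyn_cast<ShapeInferenceOpInterface>(op))
      if (shapeOp.returnsGenericArray())
        worklist.insert(op);
  });

  while (!worklist.empty()) {
    // Find the next "ready" operation: all of its operands are non-generic.
    mlir::Operation *readyOp = nullptr;
    for (mlir::Operation *op : worklist)
      if (llvm::none_of(op->getOperands(), [](mlir::Value *operand) {
            return operand->getType().isa<mlir::UnrankedTensorType>();
          })) {
        readyOp = op;
        break;
      }
    if (!readyOp)
      break; // No progress possible; a callee may need inference first.

    worklist.erase(readyOp);
    llvm::cast<ShapeInferenceOpInterface>(readyOp).inferShapes();
  }
}
```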

Finally, we can try to run `toyc test/transpose_transpose.toy -emit=mlir -opt`
and observe our pattern in action:

```MLIR(.mlir)
func @transpose_transpose(%arg0: !toy<"array">)
  attributes {toy.generic: true} {
  %0 = "toy.transpose"(%arg0) : (!toy<"array">) -> !toy<"array">
  "toy.return"(%arg0) : (!toy<"array">) -> ()
}
```

As expected, we now directly return the function argument, bypassing any
transpose operation. However, one of the transposes hasn't been eliminated.
That is not ideal! What happened is that our pattern replaced the last
transpose with the function input and left behind the now-dead transpose of
the input. The Canonicalizer knows to clean up dead operations; however, MLIR
conservatively assumes that operations may have side effects. We can fix this
by adding a new trait, `HasNoSideEffect`, to our `TransposeOp`:

```c++
class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
                                    mlir::OpTrait::OneResult,
                                    mlir::OpTrait::HasNoSideEffect> {
```

Let's now retry `toyc test/transpose_transpose.toy -emit=mlir -opt`:

```MLIR(.mlir)
func @transpose_transpose(%arg0: !toy<"array">)
  attributes {toy.generic: true} {
  "toy.return"(%arg0) : (!toy<"array">) -> ()
}
```

Perfect! No `transpose` operation is left; the code is optimal.

The code in `mlir/ToyCombine.cpp` implements a few more patterns that eliminate
trivial reshapes or fold them into constants.
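
Patterns like these can also be expressed declaratively with MLIR's TableGen
rewrite (DRR) support instead of hand-written C++; a hypothetical sketch of a
reshape-of-reshape fold, assuming the Toy operations are defined in ODS:

```Toy(.cpp)
// Fold reshape(reshape(x)) into a single reshape of x; the TableGen backend
// generates a RewritePattern equivalent to hand-written C++.
def ReshapeReshapeOptPattern : Pat<(ReshapeOp (ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;
```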

# Shape Inference and Generic Function Specialization

Our IR operates on generic arrays: we don't know the shape of the arrays other
than during the initialization of constants. However, we can propagate the
shapes through the computation until they are all known. The issue is how to
handle calls to user-defined generic functions: every call site could deduce
different shapes. One possibility would be to perform symbolic inference based
on the argument types, but this would be hard to generalize if we were to
introduce more control flow in the language. Instead we will proceed by
function specialization: for every call site with new argument shapes, we
duplicate the function and specialize it. This is akin to C++ template
instantiation:

```
template<int M1, int N1, int M2, int N2>
auto multiply_add(array<M1, N1> a, array<M2, N2> b) {
  auto prod = mul(a, b);
  auto sum = add(prod, a);
  return sum;
}
```

Every new call to `multiply_add` would instantiate the template, emit code for
the specific shapes, and deduce the return type. Clang implements this
transformation on its AST, but we will implement it in an MLIR pass here.

The `ShapeInferencePass` is a `ModulePass`: it will run on the module as a
whole. MLIR also supports `FunctionPass`es, which are restricted to modifying
a single function at a time. This pass couldn't be a function pass due to the
interprocedural nature of its transformations.

Implementing such a pass is done by creating a class inheriting from
`mlir::ModulePass` and overriding the `runOnModule()` method:

```
class ShapeInferencePass : public mlir::ModulePass<ShapeInferencePass> {

  void runOnModule() override {
    auto &module = getModule();
    ...
```

The algorithm has two levels; first, intra-procedurally:

1. Build a worklist containing all the operations that return a generic Toy
   array: these are the operations that need shape inference.
2. Iterate on the worklist:
   - find an operation to process: the next ready operation in the worklist
     has all of its arguments non-generic,
   - if no operation is found, break out of the loop,
   - remove the operation from the worklist,
   - infer the shape of its output from the argument types.
3. If the worklist is empty, the algorithm succeeded, and we infer the return
   type for the function from the return operation.

There is a twist though: when a call to a generic function is encountered,
shape inference requires the return type of the callee to be inferred first.
At this point we need to specialize the callee by cloning it. Here is the
inter-procedural flow that wraps the intra-procedural inference:

1. Keep a worklist of functions to process. Start with function "main".
2. While the worklist isn't empty:
   - Take the last inserted function in the worklist.
   - Run the intra-procedural shape inference on this function.
   - If the intra-procedural shape inference can't complete, it returns a
     FuncOp that needs to be inferred first. In this case, queue this new
     function and continue. Otherwise the inference succeeded and we can pop
     it from the queue.

The full code is in `mlir/ShapeInferencePass.cpp`.
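
A condensed sketch of that inter-procedural driver, with hypothetical helper
names (the actual implementation is in `mlir/ShapeInferencePass.cpp`):

```c++
// Inter-procedural driver sketch: specialize and infer callees on demand.
// inferFunctionShapes() is assumed to run the intra-procedural algorithm and
// return the callee blocking inference, or a null FuncOp on success.
void runOnModuleSketch(mlir::ModuleOp module) {
  llvm::SmallVector<mlir::FuncOp, 8> worklist;
  worklist.push_back(module.lookupSymbol<mlir::FuncOp>("main"));
  while (!worklist.empty()) {
    mlir::FuncOp f = worklist.back();
    if (mlir::FuncOp blockedOn = inferFunctionShapes(f))
      worklist.push_back(blockedOn); // Infer (a clone of) the callee first.
    else
      worklist.pop_back(); // Inference succeeded for this function.
  }
}
```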

# Future Work: Optimizing Buffer Allocation?

Toy is value-based. Naively, this implies a lot of allocation: what if we want
to statically optimize placement? What is the right abstraction level to
perform buffer assignment?

** Insert example here **
@ -141,6 +141,9 @@ std::unique_ptr<OpPassBase<FuncOp>> createStripDebugInfoPass();
/// Creates a pass which tests loop fusion utilities.
std::unique_ptr<OpPassBase<FuncOp>> createTestLoopFusionPass();

/// Creates a pass which inlines calls and callable operations as defined by
/// the CallGraph.
std::unique_ptr<Pass> createInlinerPass();
} // end namespace mlir

#endif // MLIR_TRANSFORMS_PASSES_H
@ -291,4 +291,8 @@ struct InlinerPass : public OperationPass<InlinerPass> {
};
} // end anonymous namespace

std::unique_ptr<Pass> mlir::createInlinerPass() {
  return std::make_unique<InlinerPass>();
}

static PassRegistration<InlinerPass> pass("inline", "Inline function calls");
@ -10,7 +10,7 @@ def main() {
  # Define a variable `a` with shape <2, 3>, initialized with the literal value.
  # The shape is inferred from the supplied literal.
  var a = [[1, 2, 3], [4, 5, 6]];
  # b is identical to a, the literal array is implicitely reshaped: defining new
  # b is identical to a, the literal array is implicitly reshaped: defining new
  # variables is the way to reshape arrays (element count must match).
  var b<2, 3> = [1, 2, 3, 4, 5, 6];
  # This call will specialize `multiply_transpose` with <2, 3> for both
@ -13,20 +13,19 @@ def main() {
  print(d);
}

# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
# CHECK-NEXT: attributes {toy.generic = true} {
# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> ()
# CHECK-NEXT: }

# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3>
# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6>
# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3>
# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()

# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>)
# CHECK-NEXT: attributes {toy.generic} {
# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> ()

# CHECK-LABEL: func @main() {
# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
@ -1,11 +1,9 @@
// RUN: not toyc-ch4 %s -emit=mlir 2>&1

// This IR is not "valid":
// The following IR is not "valid":
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
// This all round-trip since this is opaque for MLIR.
func @main() {
  %0 = "toy.print"() : () -> !toy.array<2, 3>
  %0 = "toy.print"() : () -> tensor<2x3xf64>
}
@ -6,9 +6,9 @@ def main() {
}

# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> ()
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
# CHECK-NEXT: }
@ -1,19 +0,0 @@
# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s
# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT

def transpose_transpose(x) {
  return transpose(transpose(x));
}

def main() {
  print(transpose_transpose([[1, 2], [3, 4]]));
}

#CHECK-LABEL: func @transpose_transpose
#CHECK: transpose
#CHECK-LABEL: main

#OPT-LABEL: func @transpose_transpose
#OPT-NOT: transpose
@ -1,24 +0,0 @@
# RUN: toyc-ch4 %s -emit=mlir 2>&1 | FileCheck %s
# RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s --check-prefix=OPT

# We expect no reshape in this function with optimizations enabled
def foo(a) {
  var b<2,1> = a;
  var c<2,1> = b;
  print(c);
}

def main() {
  var a<2, 1> = [1, 2];
  foo(a);
}

# without optimizations, match the reshape
#CHECK-LABEL: func @foo
#CHECK: reshape
#CHECK-LABEL: main

# with optimizations, ensure no reshape
#OPT-LABEL: main
#OPT-LABEL: func @foo_2x1
#OPT-NOT: reshape