forked from OSchip/llvm-project
[MLIR] Sketch a simple set of EDSCs to declaratively write MLIR
This CL introduces a simple set of Embedded Domain-Specific Components (EDSCs) in MLIR components: 1. a `Type` system of shell classes that closely matches the MLIR type system. These types are subdivided into `Bindable` leaf expressions and non-bindable `Expr` expressions; 2. an `MLIREmitter` class whose purpose is to: a. maintain a map of `Bindable` leaf expressions to concrete SSAValue*; b. provide helper functionality to specify bindings of `Bindable` classes to SSAValue* while verifying comformable types; c. traverse the `Expr` and emit the MLIR. This is used on a concrete example to implement MemRef load/store with clipping in the LowerVectorTransfer pass. More specifically, the following pseudo-C++ code: ```c++ MLFuncBuilder *b = ...; Location location = ...; Bindable zero, one, expr, size; // EDSL expression auto access = select(expr < zero, zero, select(expr < size, expr, size - one)); auto ssaValue = MLIREmitter(b) .bind(zero, ...) .bind(one, ...) .bind(expr, ...) .bind(size, ...) .emit(location, access); ``` is used to emit all the MLIR for a clipped MemRef access. This simple EDSL can easily be extended to more powerful patterns and should serve as the counterpart to pattern matchers (and could potentially be unified once we get enough experience). In the future, most of this code should be TableGen'd but for now it has concrete valuable uses: make MLIR programmable in a declarative fashion. This CL also adds Stmt, proper supporting free functions and rewrites VectorTransferLowering fully using EDSCs. The code for creating the EDSCs emitting a VectorTransferReadOp as loops with clipped loads is: ```c++ Stmt block = Block({ tmpAlloc = alloc(tmpMemRefType), vectorView = vector_type_cast(tmpAlloc, vectorMemRefType), ForNest(ivs, lbs, ubs, steps, { scalarValue = load(scalarMemRef, accessInfo.clippedScalarAccessExprs), store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs), }), vectorValue = load(vectorView, zero), tmpDealloc = dealloc(tmpAlloc.getLHS())}); emitter.emitStmt(block); ``` where `accessInfo.clippedScalarAccessExprs)` is created with: ```c++ select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one)); ``` The generated MLIR resembles: ```mlir %1 = dim %0, 0 : memref<?x?x?x?xf32> %2 = dim %0, 1 : memref<?x?x?x?xf32> %3 = dim %0, 2 : memref<?x?x?x?xf32> %4 = dim %0, 3 : memref<?x?x?x?xf32> %5 = alloc() : memref<5x4x3xf32> %6 = vector_type_cast %5 : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>> for %i4 = 0 to 3 { for %i5 = 0 to 4 { for %i6 = 0 to 5 { %7 = affine_apply #map0(%i0, %i4) %8 = cmpi "slt", %7, %c0 : index %9 = affine_apply #map0(%i0, %i4) %10 = cmpi "slt", %9, %1 : index %11 = affine_apply #map0(%i0, %i4) %12 = affine_apply #map1(%1, %c1) %13 = select %10, %11, %12 : index %14 = select %8, %c0, %13 : index %15 = affine_apply #map0(%i3, %i6) %16 = cmpi "slt", %15, %c0 : index %17 = affine_apply #map0(%i3, %i6) %18 = cmpi "slt", %17, %4 : index %19 = affine_apply #map0(%i3, %i6) %20 = affine_apply #map1(%4, %c1) %21 = select %18, %19, %20 : index %22 = select %16, %c0, %21 : index %23 = load %0[%14, %i1, %i2, %22] : memref<?x?x?x?xf32> store %23, %5[%i6, %i5, %i4] : memref<5x4x3xf32> } } } %24 = load %6[%c0] : memref<1xvector<5x4x3xf32>> dealloc %5 : memref<5x4x3xf32> ``` In particular notice that only 3 out of the 4-d accesses are clipped: this corresponds indeed to the number of dimensions in the super-vector. This CL also addresses the cleanups resulting from the review of the prevous CL and performs some refactoring to simplify the abstraction. PiperOrigin-RevId: 227367414
This commit is contained in:
parent
a250643ec8
commit
73f5c9c380
|
@ -0,0 +1,126 @@
|
||||||
|
//===- MLIREmitter.h - MLIR EDSC Emitter Class ------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Provides a simple interface to bind leaf edsc::Expr to Value* and emit the
|
||||||
|
// corresponding MLIR.
|
||||||
|
//
|
||||||
|
// In a first approximation this EDSC can be viewed as simple helper classes
|
||||||
|
// around MLIR builders. This bears resemblance with Halide but it is more
|
||||||
|
// generally designed to be automatically generated from various IR dialects in
|
||||||
|
// the future.
|
||||||
|
// The implementation is supported by a lightweight by-value abstraction on a
|
||||||
|
// scoped BumpAllocator with similarities to AffineExpr and MLFunctionMatcher.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_LIB_EDSC_MLIREMITTER_H_
|
||||||
|
#define MLIR_LIB_EDSC_MLIREMITTER_H_
|
||||||
|
|
||||||
|
#include <tuple>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mlir/EDSC/Types.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/Location.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class FuncBuilder;
|
||||||
|
class Value;
|
||||||
|
|
||||||
|
namespace edsc {
|
||||||
|
|
||||||
|
/// The MLIREmitter class is the supporting abstraction to make arbitrary MLIR
|
||||||
|
/// dialects programmable in a declarative style. As such it is a "generator"
|
||||||
|
/// counterpart of pattern-matchers.
|
||||||
|
/// The purpose of the MLIREmitter is to:
|
||||||
|
/// 1. maintain a map of `Bindable` leaf expressions to concrete Value*;
|
||||||
|
/// 2. provide helper functionality to specify bindings of `Bindable` classes
|
||||||
|
/// to Value* while verifying comformable types;
|
||||||
|
/// 3. traverse the `Expr` and emit the MLIR at the point of insertion of the
|
||||||
|
/// FuncBuilder.
|
||||||
|
struct MLIREmitter {
|
||||||
|
using BindingMap = llvm::DenseMap<Expr, Value *>;
|
||||||
|
|
||||||
|
explicit MLIREmitter(FuncBuilder *builder, Location location)
|
||||||
|
: builder(builder), location(location) {}
|
||||||
|
|
||||||
|
FuncBuilder *getBuilder() { return builder; }
|
||||||
|
Location getLocation() { return location; }
|
||||||
|
|
||||||
|
/// Registers a new binding and type-checks. If a certain Expr type is
|
||||||
|
/// registered, makes sure the Value is of the proper type.
|
||||||
|
MLIREmitter &bind(Bindable e, Value *v);
|
||||||
|
/// Constant values can be created on the spot and bound.
|
||||||
|
template <typename SSAConstantType, typename T>
|
||||||
|
MLIREmitter &bindConstant(Bindable e, T value) {
|
||||||
|
return bind(e, builder->create<SSAConstantType>(location, value));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers new bindings and type-checks. If a certain Expr type is
|
||||||
|
/// registered, makes sure the Value is of the proper type.
|
||||||
|
///
|
||||||
|
/// Binds elements one at a time. This may seem inefficient at first glance,
|
||||||
|
/// but each binding is actually type checked.
|
||||||
|
template <typename ZipRangeType>
|
||||||
|
MLIREmitter &bindZipRange(const ZipRangeType &range) {
|
||||||
|
static_assert(std::tuple_size<decltype(range.begin().iterators)>::value ==
|
||||||
|
2,
|
||||||
|
"Need a zip between 2 collections");
|
||||||
|
for (auto it : range) {
|
||||||
|
bind(std::get<0>(it), std::get<1>(it));
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename SSAConstantType, typename ZipRangeType>
|
||||||
|
MLIREmitter &bindZipRangeConstants(const ZipRangeType &range) {
|
||||||
|
static_assert(std::tuple_size<decltype(range.begin().iterators)>::value ==
|
||||||
|
2,
|
||||||
|
"Need a zip between 2 collections");
|
||||||
|
for (auto it : range) {
|
||||||
|
bindConstant<SSAConstantType>(std::get<0>(it), std::get<1>(it));
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Emits the MLIR for `expr` and inserts at the `builder`'s insertion point.
|
||||||
|
/// This function must only be called once on a given emitter.
|
||||||
|
/// Prerequisites: all the Bindables have been bound.
|
||||||
|
Value *emit(Expr expr);
|
||||||
|
llvm::SmallVector<Value *, 8> emit(llvm::ArrayRef<Expr> exprs);
|
||||||
|
|
||||||
|
/// Emits the MLIR for `stmt` and inserts at the `builder`'s insertion point.
|
||||||
|
/// Prerequisites: all the Bindables have been bound.
|
||||||
|
void emitStmt(const Stmt &stmt);
|
||||||
|
void emitStmts(llvm::ArrayRef<Stmt> stmts);
|
||||||
|
|
||||||
|
/// Returns the Value* bound to expr.
|
||||||
|
/// Prerequisite: it must exist.
|
||||||
|
Value *getValue(Expr expr) { return ssaBindings.lookup(expr); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
FuncBuilder *builder;
|
||||||
|
Location location;
|
||||||
|
BindingMap ssaBindings;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace edsc
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_LIB_EDSC_MLIREMITTER_H_
|
|
@ -0,0 +1,467 @@
|
||||||
|
//===- Types.h - MLIR EDSC Type System --------------------------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Provides a simple value-based type system to implement an EDSC that
|
||||||
|
// simplifies emitting MLIR and future MLIR dialects. Most of this should be
|
||||||
|
// auto-generated in the future.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_LIB_EDSC_TYPES_H_
|
||||||
|
#define MLIR_LIB_EDSC_TYPES_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMapInfo.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class MLIRContext;
|
||||||
|
|
||||||
|
namespace edsc {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
struct ExprStorage;
|
||||||
|
struct UnaryExprStorage;
|
||||||
|
struct BinaryExprStorage;
|
||||||
|
struct TernaryExprStorage;
|
||||||
|
struct VariadicExprStorage;
|
||||||
|
|
||||||
|
struct StmtStorage;
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
/// EDSC Types closely mirror the core MLIR and uses an abstraction similar to
|
||||||
|
/// AffineExpr:
|
||||||
|
/// 1. a set of composable structs;
|
||||||
|
/// 2. with by-value semantics everywhere and operator overloading
|
||||||
|
/// 3. with an underlying pointer to impl as payload.
|
||||||
|
/// The vast majority of this code should be TableGen'd in the future which
|
||||||
|
/// would allow us to automatically emit an EDSC for any IR dialect we are
|
||||||
|
/// interested in. In turn this makes any IR dialect fully programmable in a
|
||||||
|
/// declarative fashion.
|
||||||
|
///
|
||||||
|
/// The main differences with the AffineExpr design are as follows:
|
||||||
|
/// 1. this type-system is an empty shell to which we can lazily bind Value*
|
||||||
|
/// at the moment of emitting MLIR;
|
||||||
|
/// 2. the data structures are BumpPointer allocated in a global
|
||||||
|
/// `ScopedEDSCContext` with scoped lifetime. This allows avoiding to
|
||||||
|
/// pass and store an extra Context pointer around and keeps users honest:
|
||||||
|
/// *this is absolutely not meant to escape a local scope*.
|
||||||
|
///
|
||||||
|
/// The decision of slicing the underlying IR types into Bindable and
|
||||||
|
/// NonBindable types is flexible and influences programmability.
|
||||||
|
enum class ExprKind {
|
||||||
|
FIRST_BINDABLE_EXPR = 100,
|
||||||
|
Unbound = FIRST_BINDABLE_EXPR,
|
||||||
|
LAST_BINDABLE_EXPR = Unbound,
|
||||||
|
FIRST_NON_BINDABLE_EXPR = 200,
|
||||||
|
FIRST_UNARY_EXPR = FIRST_NON_BINDABLE_EXPR,
|
||||||
|
Dealloc = FIRST_UNARY_EXPR,
|
||||||
|
Negate,
|
||||||
|
LAST_UNARY_EXPR = Negate,
|
||||||
|
FIRST_BINARY_EXPR = 300,
|
||||||
|
Add = FIRST_BINARY_EXPR,
|
||||||
|
Sub,
|
||||||
|
Mul,
|
||||||
|
Div,
|
||||||
|
AddEQ,
|
||||||
|
SubEQ,
|
||||||
|
MulEQ,
|
||||||
|
DivEQ,
|
||||||
|
GE,
|
||||||
|
GT,
|
||||||
|
LE,
|
||||||
|
LT,
|
||||||
|
EQ,
|
||||||
|
NE,
|
||||||
|
And,
|
||||||
|
Or,
|
||||||
|
LAST_BINARY_EXPR = Or,
|
||||||
|
FIRST_TERNARY_EXPR = 400,
|
||||||
|
Select = FIRST_TERNARY_EXPR,
|
||||||
|
IfThenElse,
|
||||||
|
LAST_TERNARY_EXPR = IfThenElse,
|
||||||
|
FIRST_VARIADIC_EXPR = 500,
|
||||||
|
Alloc = FIRST_VARIADIC_EXPR, // Variadic because takes multiple dynamic shape
|
||||||
|
// values.
|
||||||
|
Load,
|
||||||
|
Store,
|
||||||
|
VectorTypeCast, // Variadic because takes a type and anything taking a type
|
||||||
|
// is variadic for now.
|
||||||
|
LAST_VARIADIC_EXPR = VectorTypeCast,
|
||||||
|
FIRST_STMT_BLOCK_LIKE_EXPR = 600,
|
||||||
|
Block = FIRST_STMT_BLOCK_LIKE_EXPR,
|
||||||
|
For,
|
||||||
|
LAST_STMT_BLOCK_LIKE_EXPR = For,
|
||||||
|
LAST_NON_BINDABLE_EXPR = LAST_STMT_BLOCK_LIKE_EXPR,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Scoped context holding a BumpPtrAllocator.
|
||||||
|
/// Creating such an object injects a new allocator in Expr::globalAllocator.
|
||||||
|
/// At the moment we can have only have one such context.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
///
|
||||||
|
/// ```c++
|
||||||
|
/// MLFunctionBuilder *b = ...;
|
||||||
|
/// Location someLocation = ...;
|
||||||
|
/// Value *zeroValue = ...;
|
||||||
|
/// Value *oneValue = ...;
|
||||||
|
///
|
||||||
|
/// ScopedEDSCContext raiiContext;
|
||||||
|
/// Constant zero, one;
|
||||||
|
/// Value *val = MLIREmitter(b)
|
||||||
|
/// .bind(zero, zeroValue)
|
||||||
|
/// .bind(one, oneValue)
|
||||||
|
/// .emit(someLocation, zero + one);
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// will emit MLIR resembling:
|
||||||
|
///
|
||||||
|
/// ```mlir
|
||||||
|
/// %2 = add(%c0, %c1) : index
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// The point of the EDSC is to synthesize arbitrarily more complex patterns in
|
||||||
|
/// a declarative fashion. For example, clipping for guaranteed in-bounds access
|
||||||
|
/// can be written:
|
||||||
|
///
|
||||||
|
/// ```c++
|
||||||
|
/// auto expr = select(expr < 0, 0, select(expr < size, expr, size - 1));
|
||||||
|
/// Value *val = MLIREmitter(b).bind(...).emit(loc, expr);
|
||||||
|
/// ```
|
||||||
|
struct ScopedEDSCContext {
|
||||||
|
ScopedEDSCContext();
|
||||||
|
~ScopedEDSCContext();
|
||||||
|
llvm::BumpPtrAllocator allocator;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Expr {
|
||||||
|
public:
|
||||||
|
using ImplType = detail::ExprStorage;
|
||||||
|
|
||||||
|
/// Returns the scoped BumpPtrAllocator. This must be done in the context of a
|
||||||
|
/// unique `ScopedEDSCContext` declared in an RAII fashion in some enclosing
|
||||||
|
/// scope.
|
||||||
|
static llvm::BumpPtrAllocator *&globalAllocator() {
|
||||||
|
static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
|
||||||
|
return allocator;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr() : storage(nullptr) {}
|
||||||
|
/* implicit */ Expr(ImplType *storage) : storage(storage) {}
|
||||||
|
|
||||||
|
Expr(const Expr &other) : storage(other.storage) {}
|
||||||
|
Expr &operator=(Expr other) {
|
||||||
|
storage = other.storage;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit operator bool() { return storage; }
|
||||||
|
bool operator!() { return storage == nullptr; }
|
||||||
|
|
||||||
|
template <typename U> bool isa() const;
|
||||||
|
template <typename U> U dyn_cast() const;
|
||||||
|
template <typename U> U cast() const;
|
||||||
|
|
||||||
|
MLIRContext *getContext() const;
|
||||||
|
|
||||||
|
/// Returns the classification for this type.
|
||||||
|
ExprKind getKind() const;
|
||||||
|
|
||||||
|
void print(raw_ostream &os) const;
|
||||||
|
void dump() const;
|
||||||
|
|
||||||
|
/// Creates the BinaryExpr corresponding to the operator.
|
||||||
|
Expr operator+(Expr other) const;
|
||||||
|
Expr operator-(Expr other) const;
|
||||||
|
/// In particular operator==, operator!= return a new Expr and *not* a bool.
|
||||||
|
Expr operator==(Expr other) const;
|
||||||
|
Expr operator!=(Expr other) const;
|
||||||
|
Expr operator<(Expr other) const;
|
||||||
|
Expr operator<=(Expr other) const;
|
||||||
|
Expr operator>(Expr other) const;
|
||||||
|
Expr operator>=(Expr other) const;
|
||||||
|
Expr operator&&(Expr other) const;
|
||||||
|
Expr operator||(Expr other) const;
|
||||||
|
|
||||||
|
/// For debugging purposes.
|
||||||
|
const void *getStoragePtr() const { return storage; }
|
||||||
|
|
||||||
|
friend ::llvm::hash_code hash_value(Expr arg);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ImplType *storage;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Bindable : public Expr {
|
||||||
|
using ImplType = detail::ExprStorage;
|
||||||
|
friend class Expr;
|
||||||
|
Bindable(ExprKind kind = ExprKind::Unbound);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Bindable(Expr::ImplType *ptr) : Expr(ptr) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UnaryExpr : public Expr {
|
||||||
|
using ImplType = detail::UnaryExprStorage;
|
||||||
|
friend class Expr;
|
||||||
|
|
||||||
|
UnaryExpr(ExprKind kind, Expr expr);
|
||||||
|
Expr getExpr() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
UnaryExpr(Expr::ImplType *ptr) : Expr(ptr) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BinaryExpr : public Expr {
|
||||||
|
using ImplType = detail::BinaryExprStorage;
|
||||||
|
friend class Expr;
|
||||||
|
BinaryExpr(ExprKind kind, Expr lhs, Expr rhs);
|
||||||
|
Expr getLHS() const;
|
||||||
|
Expr getRHS() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
BinaryExpr(Expr::ImplType *ptr) : Expr(ptr) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TernaryExpr : public Expr {
|
||||||
|
using ImplType = detail::TernaryExprStorage;
|
||||||
|
friend class Expr;
|
||||||
|
TernaryExpr(ExprKind kind, Expr cond, Expr lhs, Expr rhs);
|
||||||
|
Expr getCond() const;
|
||||||
|
Expr getLHS() const;
|
||||||
|
Expr getRHS() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TernaryExpr(Expr::ImplType *ptr) : Expr(ptr) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VariadicExpr : public Expr {
|
||||||
|
using ImplType = detail::VariadicExprStorage;
|
||||||
|
friend class Expr;
|
||||||
|
VariadicExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
|
||||||
|
llvm::ArrayRef<Type> types = {});
|
||||||
|
llvm::ArrayRef<Expr> getExprs() const;
|
||||||
|
llvm::ArrayRef<Type> getTypes() const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
VariadicExpr(Expr::ImplType *ptr) : Expr(ptr) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StmtBlockLikeExpr : public VariadicExpr {
|
||||||
|
using ImplType = detail::VariadicExprStorage;
|
||||||
|
friend class Expr;
|
||||||
|
StmtBlockLikeExpr(ExprKind kind, llvm::ArrayRef<Expr> exprs,
|
||||||
|
llvm::ArrayRef<Type> types = {})
|
||||||
|
: VariadicExpr(kind, exprs, types) {}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
StmtBlockLikeExpr(Expr::ImplType *ptr) : VariadicExpr(ptr) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A Stmt represent a unit of liaison betweeb a Bindable `lhs`, an Expr `rhs`
|
||||||
|
/// and a list of `enclosingStmts`. This essentially allows giving a name and a
|
||||||
|
/// scoping to objects of type `Expr` so they can be reused once bound to an
|
||||||
|
/// Value*. This enables writing generators such as:
|
||||||
|
///
|
||||||
|
/// ```mlir
|
||||||
|
/// Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
|
||||||
|
/// Stmt block = Block({
|
||||||
|
/// tmpAlloc = alloc(tmpMemRefType),
|
||||||
|
/// vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||||
|
/// ForNest(ivs, lbs, ubs, steps, {
|
||||||
|
/// scalarValue = load(scalarMemRef,
|
||||||
|
/// accessInfo.clippedScalarAccessExprs), store(scalarValue, tmpAlloc,
|
||||||
|
/// accessInfo.tmpAccessExprs),
|
||||||
|
/// }),
|
||||||
|
/// vectorValue = load(vectorView, zero),
|
||||||
|
/// tmpDealloc = dealloc(tmpAlloc.getLHS())});
|
||||||
|
/// emitter.emitStmt(block);
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// A Stmt can be declared with either:
|
||||||
|
/// 1. default initialization (e.g. `Stmt foo;`) in which case all of its `lhs`,
|
||||||
|
/// `rhs` and `enclosingStmts` are unbound;
|
||||||
|
/// 2. initialization from an Expr without a Bindable `lhs`
|
||||||
|
/// (e.g. store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs)), in which
|
||||||
|
/// case the `lhs` is unbound;
|
||||||
|
/// 3. an assignment operator to a `lhs` Stmt that is bound implicitly:
|
||||||
|
/// (e.g. vectorValue = load(vectorView, zero)).
|
||||||
|
///
|
||||||
|
/// Only ExprKind::StmtBlockLikeExpr have `enclosedStmts`, these comprise:
|
||||||
|
/// 1. `For`-loops for which the `lhs` binds to the induction variable, `rhs`
|
||||||
|
/// binds to an Expr of kind `ExprKind::For` with lower-bound, upper-bound and
|
||||||
|
/// step respectively;
|
||||||
|
/// 2. `Block` with an Expr of kind `ExprKind::Block` and which has no `rhs` but
|
||||||
|
/// only `enclosingStmts`.
|
||||||
|
struct Stmt {
|
||||||
|
using ImplType = detail::StmtStorage;
|
||||||
|
friend class Expr;
|
||||||
|
Stmt() : storage(nullptr) {}
|
||||||
|
Stmt(const Stmt &other) : storage(other.storage) {}
|
||||||
|
Stmt operator=(const Stmt &other) {
|
||||||
|
this->storage = other.storage; // NBD if &other == this
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
|
||||||
|
Stmt(const Bindable &lhs, const Expr &rhs,
|
||||||
|
llvm::ArrayRef<Stmt> stmts = llvm::ArrayRef<Stmt>());
|
||||||
|
Stmt &operator=(const Expr &expr);
|
||||||
|
|
||||||
|
operator Expr() const { return getLHS(); }
|
||||||
|
|
||||||
|
/// For debugging purposes.
|
||||||
|
const void *getStoragePtr() const { return storage; }
|
||||||
|
|
||||||
|
void print(raw_ostream &os, llvm::Twine indent = "") const;
|
||||||
|
void dump() const;
|
||||||
|
|
||||||
|
Bindable getLHS() const;
|
||||||
|
Expr getRHS() const;
|
||||||
|
llvm::ArrayRef<Stmt> getEnclosedStmts() const;
|
||||||
|
|
||||||
|
Expr operator+(Stmt other) const { return getLHS() + other.getLHS(); }
|
||||||
|
Expr operator-(Stmt other) const { return getLHS() - other.getLHS(); }
|
||||||
|
|
||||||
|
Expr operator<(Stmt other) const { return getLHS() + other.getLHS(); }
|
||||||
|
Expr operator<=(Stmt other) const { return getLHS() + other.getLHS(); }
|
||||||
|
Expr operator>(Stmt other) const { return getLHS() + other.getLHS(); }
|
||||||
|
Expr operator>=(Stmt other) const { return getLHS() + other.getLHS(); }
|
||||||
|
Expr operator&&(Stmt other) const { return getLHS() + other.getLHS(); }
|
||||||
|
Expr operator||(Stmt other) const { return getLHS() + other.getLHS(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ImplType *storage;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename U> bool Expr::isa() const {
|
||||||
|
auto kind = getKind();
|
||||||
|
if (std::is_same<U, Bindable>::value) {
|
||||||
|
return kind >= ExprKind::FIRST_BINDABLE_EXPR &&
|
||||||
|
kind <= ExprKind::LAST_BINDABLE_EXPR;
|
||||||
|
}
|
||||||
|
if (std::is_same<U, UnaryExpr>::value) {
|
||||||
|
return kind >= ExprKind::FIRST_UNARY_EXPR &&
|
||||||
|
kind <= ExprKind::LAST_UNARY_EXPR;
|
||||||
|
}
|
||||||
|
if (std::is_same<U, BinaryExpr>::value) {
|
||||||
|
return kind >= ExprKind::FIRST_BINARY_EXPR &&
|
||||||
|
kind <= ExprKind::LAST_BINARY_EXPR;
|
||||||
|
}
|
||||||
|
if (std::is_same<U, TernaryExpr>::value) {
|
||||||
|
return kind >= ExprKind::FIRST_TERNARY_EXPR &&
|
||||||
|
kind <= ExprKind::LAST_TERNARY_EXPR;
|
||||||
|
}
|
||||||
|
if (std::is_same<U, VariadicExpr>::value) {
|
||||||
|
return kind >= ExprKind::FIRST_VARIADIC_EXPR &&
|
||||||
|
kind <= ExprKind::LAST_VARIADIC_EXPR;
|
||||||
|
}
|
||||||
|
if (std::is_same<U, StmtBlockLikeExpr>::value) {
|
||||||
|
return kind >= ExprKind::FIRST_STMT_BLOCK_LIKE_EXPR &&
|
||||||
|
kind <= ExprKind::LAST_STMT_BLOCK_LIKE_EXPR;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename U> U Expr::dyn_cast() const {
|
||||||
|
if (isa<U>()) {
|
||||||
|
return U(storage);
|
||||||
|
}
|
||||||
|
return U(nullptr);
|
||||||
|
}
|
||||||
|
template <typename U> U Expr::cast() const {
|
||||||
|
assert(isa<U>());
|
||||||
|
return U(storage);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Make Expr hashable.
|
||||||
|
inline ::llvm::hash_code hash_value(Expr arg) {
|
||||||
|
return ::llvm::hash_value(arg.storage);
|
||||||
|
}
|
||||||
|
|
||||||
|
raw_ostream &operator<<(raw_ostream &os, const Expr &expr);
|
||||||
|
raw_ostream &operator<<(raw_ostream &os, const Stmt &stmt);
|
||||||
|
|
||||||
|
} // namespace edsc
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
|
||||||
|
// Expr hash just like pointers
|
||||||
|
template <> struct DenseMapInfo<mlir::edsc::Expr> {
|
||||||
|
static mlir::edsc::Expr getEmptyKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
|
return mlir::edsc::Expr(static_cast<mlir::edsc::Expr::ImplType *>(pointer));
|
||||||
|
}
|
||||||
|
static mlir::edsc::Expr getTombstoneKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||||
|
return mlir::edsc::Expr(static_cast<mlir::edsc::Expr::ImplType *>(pointer));
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(mlir::edsc::Expr val) {
|
||||||
|
return mlir::edsc::hash_value(val);
|
||||||
|
}
|
||||||
|
static bool isEqual(mlir::edsc::Expr LHS, mlir::edsc::Expr RHS) {
|
||||||
|
return LHS.getStoragePtr() == RHS.getStoragePtr();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace llvm
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace edsc {
|
||||||
|
|
||||||
|
/// Free function sugar.
|
||||||
|
///
|
||||||
|
/// Since bindings are hashed by the underlying pointer address, we need to be
|
||||||
|
/// sure to construct new elements in a vector. We cannot just use
|
||||||
|
/// `llvm::SmallVector<Bindable, 8> dims(n);` directly because a single
|
||||||
|
/// `Bindable` will be default constructed and copied everywhere in the vector.
|
||||||
|
/// Hilarity ensues when trying to bind structs that are already bound.
|
||||||
|
llvm::SmallVector<Bindable, 8> makeBindables(unsigned n);
|
||||||
|
llvm::SmallVector<Expr, 8> makeExprs(unsigned n);
|
||||||
|
llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables);
|
||||||
|
template <typename IterTy>
|
||||||
|
llvm::SmallVector<Expr, 8> makeExprs(IterTy begin, IterTy end) {
|
||||||
|
return llvm::SmallVector<Expr, 8>(begin, end);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr alloc(llvm::ArrayRef<Expr> sizes, Type memrefType);
|
||||||
|
inline Expr alloc(Type memrefType) { return alloc({}, memrefType); }
|
||||||
|
Expr dealloc(Expr memref);
|
||||||
|
Expr load(Expr m, llvm::ArrayRef<Expr> indices);
|
||||||
|
Expr store(Expr val, Expr m, llvm::ArrayRef<Expr> indices);
|
||||||
|
Expr select(Expr cond, Expr lhs, Expr rhs);
|
||||||
|
Expr vector_type_cast(Expr memrefExpr, Type memrefType);
|
||||||
|
|
||||||
|
Stmt Block(ArrayRef<Stmt> stmts);
|
||||||
|
Stmt For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> enclosedStmts);
|
||||||
|
Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||||
|
ArrayRef<Stmt> enclosedStmts);
|
||||||
|
Stmt ForNest(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
|
||||||
|
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
|
||||||
|
ArrayRef<Stmt> enclosedStmts);
|
||||||
|
|
||||||
|
} // namespace edsc
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_LIB_EDSC_TYPES_H_
|
|
@ -104,7 +104,7 @@ template <typename OpClass> struct op_matcher {
|
||||||
|
|
||||||
} // end namespace detail
|
} // end namespace detail
|
||||||
|
|
||||||
/// Entry point for matching a pattern over an SSAValue.
|
/// Entry point for matching a pattern over a Value.
|
||||||
template <typename Pattern>
|
template <typename Pattern>
|
||||||
inline bool matchPattern(Value *value, const Pattern &pattern) {
|
inline bool matchPattern(Value *value, const Pattern &pattern) {
|
||||||
// TODO: handle other cases
|
// TODO: handle other cases
|
||||||
|
|
|
@ -31,11 +31,11 @@ namespace mlir {
|
||||||
namespace functional {
|
namespace functional {
|
||||||
|
|
||||||
/// Map with iterators.
|
/// Map with iterators.
|
||||||
template <typename Fun, typename IterType>
|
template <typename Fn, typename IterType>
|
||||||
auto map(Fun fun, IterType begin, IterType end)
|
auto map(Fn fun, IterType begin, IterType end)
|
||||||
-> llvm::SmallVector<typename std::result_of<Fun(decltype(*begin))>::type,
|
-> llvm::SmallVector<typename std::result_of<Fn(decltype(*begin))>::type,
|
||||||
8> {
|
8> {
|
||||||
using R = typename std::result_of<Fun(decltype(*begin))>::type;
|
using R = typename std::result_of<Fn(decltype(*begin))>::type;
|
||||||
llvm::SmallVector<R, 8> res;
|
llvm::SmallVector<R, 8> res;
|
||||||
// auto i works with both pointer types and value types with an operator*.
|
// auto i works with both pointer types and value types with an operator*.
|
||||||
// auto *i only works for pointer types.
|
// auto *i only works for pointer types.
|
||||||
|
@ -46,15 +46,34 @@ auto map(Fun fun, IterType begin, IterType end)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Map with templated container.
|
/// Map with templated container.
|
||||||
template <typename Fun, typename ContainerType>
|
template <typename Fn, typename ContainerType>
|
||||||
auto map(Fun fun, ContainerType input)
|
auto map(Fn fun, ContainerType input)
|
||||||
-> decltype(map(fun, std::begin(input), std::end(input))) {
|
-> decltype(map(fun, std::begin(input), std::end(input))) {
|
||||||
return map(fun, std::begin(input), std::end(input));
|
return map(fun, std::begin(input), std::end(input));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Zip map with 2 templated container, iterates to the min of the sizes of
|
||||||
|
/// the 2 containers.
|
||||||
|
/// TODO(ntv): make variadic when needed.
|
||||||
|
template <typename Fn, typename ContainerType1, typename ContainerType2>
|
||||||
|
auto zipMap(Fn fun, ContainerType1 input1, ContainerType2 input2)
|
||||||
|
-> llvm::SmallVector<
|
||||||
|
typename std::result_of<Fn(decltype(*input1.begin()),
|
||||||
|
decltype(*input2.begin()))>::type,
|
||||||
|
8> {
|
||||||
|
using R = typename std::result_of<Fn(decltype(*input1.begin()),
|
||||||
|
decltype(*input2.begin()))>::type;
|
||||||
|
llvm::SmallVector<R, 8> res;
|
||||||
|
auto zipIter = llvm::zip(input1, input2);
|
||||||
|
for (auto it : zipIter) {
|
||||||
|
res.push_back(fun(std::get<0>(it), std::get<1>(it)));
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
/// Apply with iterators.
|
/// Apply with iterators.
|
||||||
template <typename Fun, typename IterType>
|
template <typename Fn, typename IterType>
|
||||||
void apply(Fun fun, IterType begin, IterType end) {
|
void apply(Fn fun, IterType begin, IterType end) {
|
||||||
// auto i works with both pointer types and value types with an operator*.
|
// auto i works with both pointer types and value types with an operator*.
|
||||||
// auto *i only works for pointer types.
|
// auto *i only works for pointer types.
|
||||||
for (auto i = begin; i != end; ++i) {
|
for (auto i = begin; i != end; ++i) {
|
||||||
|
@ -63,16 +82,16 @@ void apply(Fun fun, IterType begin, IterType end) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply with templated container.
|
/// Apply with templated container.
|
||||||
template <typename Fun, typename ContainerType>
|
template <typename Fn, typename ContainerType>
|
||||||
void apply(Fun fun, ContainerType input) {
|
void apply(Fn fun, ContainerType input) {
|
||||||
return apply(fun, std::begin(input), std::end(input));
|
return apply(fun, std::begin(input), std::end(input));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Zip apply with 2 templated container, iterates to the min of the sizes of
|
/// Zip apply with 2 templated container, iterates to the min of the sizes of
|
||||||
/// the 2 containers.
|
/// the 2 containers.
|
||||||
/// TODO(ntv): make variadic.
|
/// TODO(ntv): make variadic when needed.
|
||||||
template <typename Fun, typename ContainerType1, typename ContainerType2>
|
template <typename Fn, typename ContainerType1, typename ContainerType2>
|
||||||
void zipApply(Fun fun, ContainerType1 input1, ContainerType2 input2) {
|
void zipApply(Fn fun, ContainerType1 input1, ContainerType2 input2) {
|
||||||
auto zipIter = llvm::zip(input1, input2);
|
auto zipIter = llvm::zip(input1, input2);
|
||||||
for (auto it : zipIter) {
|
for (auto it : zipIter) {
|
||||||
fun(std::get<0>(it), std::get<1>(it));
|
fun(std::get<0>(it), std::get<1>(it));
|
||||||
|
|
|
@ -0,0 +1,330 @@
|
||||||
|
//===- MLIREmitter.cpp - MLIR EDSC Emitter Class Implementation -*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include "mlir/EDSC/MLIREmitter.h"
|
||||||
|
#include "mlir/EDSC/Types.h"
|
||||||
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/Instructions.h"
|
||||||
|
#include "mlir/IR/Location.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/StandardOps/StandardOps.h"
|
||||||
|
#include "mlir/SuperVectorOps/SuperVectorOps.h"
|
||||||
|
#include "mlir/Support/Functional.h"
|
||||||
|
#include "mlir/Support/STLExtras.h"
|
||||||
|
|
||||||
|
using llvm::dbgs;
|
||||||
|
using llvm::errs;
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "edsc"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace edsc {
|
||||||
|
|
||||||
|
// Factors out the boilerplate that is needed to build and answer the
|
||||||
|
// following simple question:
|
||||||
|
// Given a set of Value* `values`, how do I get the resulting op(`values`)
|
||||||
|
//
|
||||||
|
// This is a very loaded question and generally cannot be answered properly.
|
||||||
|
// For instance, an LLVM operation has many attributes that may not fit within
|
||||||
|
// this simplistic framing (e.g. overflow behavior etc).
|
||||||
|
//
|
||||||
|
// Still, MLIR is a higher-level IR and the Halide experience shows it is
|
||||||
|
// possible to build useful EDSCs with the right amount of sugar.
|
||||||
|
//
|
||||||
|
// To build EDSCs we need to be able to conveniently support simple operations
|
||||||
|
// such as `add` on the type system. This captures the possible behaviors. In
|
||||||
|
// the future, this should be automatically constructed from an abstraction
|
||||||
|
// that is common to the IR verifier, but for now we need to get off the ground
|
||||||
|
// manually.
|
||||||
|
//
|
||||||
|
// This is expected to be a "dialect-specific" functionality: certain dialects
|
||||||
|
// will not have a simple definition. Two such cases that come to mind are:
|
||||||
|
// 1. what does it mean to have an operator* on an opaque tensor dialect
|
||||||
|
// (dot, vector, hadamard, kronecker ?)-product;
|
||||||
|
// 2. LLVM add with attributes like overflow.
|
||||||
|
// This is all left for future consideration; in the meantime let's separate
|
||||||
|
// concerns and implement useful infrastructure without solving all problems at
|
||||||
|
// once.
|
||||||
|
|
||||||
|
/// Returns the element type if the type is VectorType or MemRefType; returns
|
||||||
|
/// getType if the type is scalar.
|
||||||
|
static Type getElementType(const Value &v) {
|
||||||
|
if (auto vec = v.getType().dyn_cast<mlir::VectorType>()) {
|
||||||
|
return vec.getElementType();
|
||||||
|
}
|
||||||
|
if (auto mem = v.getType().dyn_cast<mlir::MemRefType>()) {
|
||||||
|
return mem.getElementType();
|
||||||
|
}
|
||||||
|
return v.getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isIndexElement(const Value &v) {
|
||||||
|
return getElementType(v).isIndex();
|
||||||
|
}
|
||||||
|
static bool isIntElement(const Value &v) {
|
||||||
|
return getElementType(v).isa<IntegerType>();
|
||||||
|
}
|
||||||
|
static bool isFloatElement(const Value &v) {
|
||||||
|
return getElementType(v).isa<FloatType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value *add(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||||
|
if (isIndexElement(*a)) {
|
||||||
|
auto *context = builder->getContext();
|
||||||
|
auto d0 = getAffineDimExpr(0, context);
|
||||||
|
auto d1 = getAffineDimExpr(1, context);
|
||||||
|
auto map = AffineMap::get(2, 0, {d0 + d1}, {});
|
||||||
|
return builder
|
||||||
|
->create<AffineApplyOp>(location, map, ArrayRef<Value *>{a, b})
|
||||||
|
->getResult(0);
|
||||||
|
} else if (isIntElement(*a)) {
|
||||||
|
return builder->create<AddIOp>(location, a, b)->getResult();
|
||||||
|
}
|
||||||
|
assert(isFloatElement(*a) && "Expected float element");
|
||||||
|
return builder->create<AddFOp>(location, a, b)->getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value *sub(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||||
|
if (isIndexElement(*a)) {
|
||||||
|
auto *context = builder->getContext();
|
||||||
|
auto d0 = getAffineDimExpr(0, context);
|
||||||
|
auto d1 = getAffineDimExpr(1, context);
|
||||||
|
auto map = AffineMap::get(2, 0, {d0 - d1}, {});
|
||||||
|
return builder
|
||||||
|
->create<AffineApplyOp>(location, map, ArrayRef<Value *>{a, b})
|
||||||
|
->getResult(0);
|
||||||
|
} else if (isIntElement(*a)) {
|
||||||
|
return builder->create<SubIOp>(location, a, b)->getResult();
|
||||||
|
}
|
||||||
|
assert(isFloatElement(*a) && "Expected float element");
|
||||||
|
return builder->create<SubFOp>(location, a, b)->getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value *mul(FuncBuilder *builder, Location location, Value *a, Value *b) {
|
||||||
|
if (!isFloatElement(*a)) {
|
||||||
|
return builder->create<MulIOp>(location, a, b)->getResult();
|
||||||
|
}
|
||||||
|
assert(isFloatElement(*a) && "Expected float element");
|
||||||
|
return builder->create<MulFOp>(location, a, b)->getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
|
||||||
|
const auto *inst = v.getDefiningInst();
|
||||||
|
if (inst) {
|
||||||
|
inst->print(os);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// &v is required here otherwise we get:
|
||||||
|
// non-pointer operand type 'const mlir::ForInst' incompatible with nullptr
|
||||||
|
if (auto *forInst = dyn_cast<ForInst>(&v)) {
|
||||||
|
forInst->print(os);
|
||||||
|
} else {
|
||||||
|
os << "unknown_ssa_value";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
MLIREmitter &MLIREmitter::bind(Bindable e, Value *v) {
|
||||||
|
LLVM_DEBUG(printDefininingStatement(llvm::dbgs() << "\nBinding " << e << " @"
|
||||||
|
<< e.getStoragePtr() << ": ",
|
||||||
|
*v));
|
||||||
|
auto it = ssaBindings.insert(std::make_pair(e, v));
|
||||||
|
if (!it.second) {
|
||||||
|
printDefininingStatement(
|
||||||
|
llvm::errs() << "\nRebinding " << e << " @" << e.getStoragePtr(), *v);
|
||||||
|
llvm_unreachable("Double binding!");
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value *MLIREmitter::emit(Expr e) {
|
||||||
|
auto it = ssaBindings.find(e);
|
||||||
|
if (it != ssaBindings.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip bindables, they must have been found already.
|
||||||
|
Value *res = nullptr;
|
||||||
|
if (auto un = e.dyn_cast<UnaryExpr>()) {
|
||||||
|
if (un.getKind() == ExprKind::Dealloc) {
|
||||||
|
builder->create<DeallocOp>(location, emit(un.getExpr()));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
} else if (auto bin = e.dyn_cast<BinaryExpr>()) {
|
||||||
|
auto *a = emit(bin.getLHS());
|
||||||
|
auto *b = emit(bin.getRHS());
|
||||||
|
if (bin.getKind() == ExprKind::Add) {
|
||||||
|
res = add(builder, location, a, b);
|
||||||
|
} else if (bin.getKind() == ExprKind::Sub) {
|
||||||
|
res = sub(builder, location, a, b);
|
||||||
|
} else if (bin.getKind() == ExprKind::Mul) {
|
||||||
|
res = mul(builder, location, a, b);
|
||||||
|
}
|
||||||
|
// Vanilla comparisons operators.
|
||||||
|
// else if (bin.getKind() == ExprKind::And) {
|
||||||
|
// // impl i1
|
||||||
|
// res = add(builder, location, a, b); // MulIOp on i1
|
||||||
|
// }
|
||||||
|
// else if (bin.getKind() == ExprKind::Not) {
|
||||||
|
// res = ...; // 1 - cast<i1>()
|
||||||
|
// }
|
||||||
|
// else if (bin.getKind() == ExprKind::Or) {
|
||||||
|
// res = ...; // not(not(a) and not(b))
|
||||||
|
// }
|
||||||
|
|
||||||
|
// TODO(ntv): signed vs unsiged ??
|
||||||
|
// TODO(ntv): integer vs not ??
|
||||||
|
// TODO(ntv): float cmp
|
||||||
|
else if (bin.getKind() == ExprKind::EQ) {
|
||||||
|
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::EQ, a, b);
|
||||||
|
} else if (bin.getKind() == ExprKind::NE) {
|
||||||
|
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::NE, a, b);
|
||||||
|
} else if (bin.getKind() == ExprKind::LT) {
|
||||||
|
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SLT, a, b);
|
||||||
|
} else if (bin.getKind() == ExprKind::LE) {
|
||||||
|
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SLE, a, b);
|
||||||
|
} else if (bin.getKind() == ExprKind::GT) {
|
||||||
|
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SGT, a, b);
|
||||||
|
} else if (bin.getKind() == ExprKind::GE) {
|
||||||
|
res = builder->create<CmpIOp>(location, mlir::CmpIPredicate::SGE, a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ntv): do we want this?
|
||||||
|
// if (res && ((a->type().is_uint() && !b->type().is_uint()) ||
|
||||||
|
// (!a->type().is_uint() && b->type().is_uint()))) {
|
||||||
|
// std::stringstream ss;
|
||||||
|
// ss << "a: " << *a << "\t b: " << *b;
|
||||||
|
// res->getDefiningOperation()->emitWarning(
|
||||||
|
// "Mixing signed and unsigned integers: " + ss.str());
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto ter = e.dyn_cast<TernaryExpr>()) {
|
||||||
|
if (ter.getKind() == ExprKind::Select) {
|
||||||
|
auto *cond = emit(ter.getCond());
|
||||||
|
auto *lhs = emit(ter.getLHS());
|
||||||
|
auto *rhs = emit(ter.getRHS());
|
||||||
|
res = builder->create<SelectOp>(location, cond, lhs, rhs)->getResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto nar = e.dyn_cast<VariadicExpr>()) {
|
||||||
|
if (nar.getKind() == ExprKind::Alloc) {
|
||||||
|
auto exprs = emit(nar.getExprs());
|
||||||
|
auto types = nar.getTypes();
|
||||||
|
assert(types.size() == 1 && "Expected 1 type");
|
||||||
|
res =
|
||||||
|
builder->create<AllocOp>(location, types[0].cast<MemRefType>(), exprs)
|
||||||
|
->getResult();
|
||||||
|
} else if (nar.getKind() == ExprKind::Load) {
|
||||||
|
auto exprs = emit(nar.getExprs());
|
||||||
|
assert(exprs.size() > 1 && "Expected > 1 expr");
|
||||||
|
assert(nar.getTypes().empty() && "Expected no type");
|
||||||
|
SmallVector<Value *, 8> vals(exprs.begin() + 1, exprs.end());
|
||||||
|
res = builder->create<LoadOp>(location, exprs[0], vals)->getResult();
|
||||||
|
} else if (nar.getKind() == ExprKind::Store) {
|
||||||
|
auto exprs = emit(nar.getExprs());
|
||||||
|
assert(exprs.size() > 2 && "Expected > 2 expr");
|
||||||
|
assert(nar.getTypes().empty() && "Expected no type");
|
||||||
|
SmallVector<Value *, 8> vals(exprs.begin() + 2, exprs.end());
|
||||||
|
builder->create<StoreOp>(location, exprs[0], exprs[1], vals);
|
||||||
|
return nullptr;
|
||||||
|
} else if (nar.getKind() == ExprKind::VectorTypeCast) {
|
||||||
|
auto exprs = emit(nar.getExprs());
|
||||||
|
assert(exprs.size() == 1 && "Expected 1 expr");
|
||||||
|
auto types = nar.getTypes();
|
||||||
|
assert(types.size() == 1 && "Expected 1 type");
|
||||||
|
res = builder
|
||||||
|
->create<VectorTypeCastOp>(location, exprs[0],
|
||||||
|
types[0].cast<MemRefType>())
|
||||||
|
->getResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto expr = e.dyn_cast<StmtBlockLikeExpr>()) {
|
||||||
|
if (expr.getKind() == ExprKind::For) {
|
||||||
|
auto exprs = emit(expr.getExprs());
|
||||||
|
assert(exprs.size() == 3 && "Expected 3 exprs");
|
||||||
|
assert(expr.getTypes().empty() && "Expected no type");
|
||||||
|
auto lb =
|
||||||
|
exprs[0]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||||
|
auto ub =
|
||||||
|
exprs[1]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||||
|
auto step =
|
||||||
|
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||||
|
res = builder->createFor(location, lb, ub, step);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!res) {
|
||||||
|
e.print(llvm::errs() << "\nError @" << e.getStoragePtr() << ": ");
|
||||||
|
assert(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resIter = ssaBindings.insert(std::make_pair(e, res));
|
||||||
|
(void)resIter;
|
||||||
|
assert(resIter.second && "insertion failed");
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value *, 8> MLIREmitter::emit(ArrayRef<Expr> exprs) {
|
||||||
|
return mlir::functional::map(
|
||||||
|
[this](Expr e) {
|
||||||
|
auto *res = this->emit(e);
|
||||||
|
LLVM_DEBUG(
|
||||||
|
printDefininingStatement(llvm::dbgs() << "\nEmitted: ", *res));
|
||||||
|
return res;
|
||||||
|
},
|
||||||
|
exprs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MLIREmitter::emitStmt(const Stmt &stmt) {
|
||||||
|
auto *block = builder->getBlock();
|
||||||
|
auto ip = builder->getInsertionPoint();
|
||||||
|
// Blocks are just a containing abstraction, they do not emit their RHS.
|
||||||
|
if (stmt.getRHS().getKind() != ExprKind::Block) {
|
||||||
|
auto *val = emit(stmt.getRHS());
|
||||||
|
if (!val) {
|
||||||
|
assert((stmt.getRHS().getKind() == ExprKind::Dealloc ||
|
||||||
|
stmt.getRHS().getKind() == ExprKind::Store) &&
|
||||||
|
"dealloc or store expected as the only 0-result ops");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
bind(stmt.getLHS(), val);
|
||||||
|
if (stmt.getRHS().getKind() == ExprKind::For) {
|
||||||
|
// Step into the loop.
|
||||||
|
builder->setInsertionPointToStart(cast<ForInst>(val)->getBody());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
emitStmts(stmt.getEnclosedStmts());
|
||||||
|
builder->setInsertionPoint(block, ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
void MLIREmitter::emitStmts(ArrayRef<Stmt> stmts) {
|
||||||
|
for (auto &stmt : stmts) {
|
||||||
|
emitStmt(stmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace edsc
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,398 @@
|
||||||
|
//===- Types.h - MLIR EDSC Type System Implementation -----------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#include "mlir/EDSC/Types.h"
|
||||||
|
#include "mlir/Support/STLExtras.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
using llvm::errs;
|
||||||
|
using llvm::Twine;
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::edsc;
|
||||||
|
using namespace mlir::edsc::detail;
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace edsc {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
struct ExprStorage {
|
||||||
|
ExprStorage(ExprKind kind) : kind(kind) {}
|
||||||
|
ExprKind kind;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct UnaryExprStorage : public ExprStorage {
|
||||||
|
UnaryExprStorage(ExprKind k, Expr expr) : ExprStorage(k), expr(expr) {}
|
||||||
|
Expr expr;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BinaryExprStorage : public ExprStorage {
|
||||||
|
BinaryExprStorage(ExprKind k, Expr lhs, Expr rhs)
|
||||||
|
: ExprStorage(k), lhs(lhs), rhs(rhs) {}
|
||||||
|
Expr lhs, rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TernaryExprStorage : public ExprStorage {
|
||||||
|
TernaryExprStorage(ExprKind k, Expr cond, Expr lhs, Expr rhs)
|
||||||
|
: ExprStorage(k), cond(cond), lhs(lhs), rhs(rhs) {}
|
||||||
|
Expr cond, lhs, rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VariadicExprStorage : public ExprStorage {
|
||||||
|
VariadicExprStorage(ExprKind k, ArrayRef<Expr> exprs, ArrayRef<Type> types)
|
||||||
|
: ExprStorage(k), exprs(exprs.begin(), exprs.end()),
|
||||||
|
types(types.begin(), types.end()) {}
|
||||||
|
ArrayRef<Expr> exprs;
|
||||||
|
ArrayRef<Type> types;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StmtStorage {
|
||||||
|
StmtStorage(Bindable lhs, Expr rhs, llvm::ArrayRef<Stmt> enclosedStmts)
|
||||||
|
: lhs(lhs), rhs(rhs), enclosedStmts(enclosedStmts) {}
|
||||||
|
Bindable lhs;
|
||||||
|
Expr rhs;
|
||||||
|
ArrayRef<Stmt> enclosedStmts;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
ScopedEDSCContext::ScopedEDSCContext() { Expr::globalAllocator() = &allocator; }
|
||||||
|
|
||||||
|
ScopedEDSCContext::~ScopedEDSCContext() { Expr::globalAllocator() = nullptr; }
|
||||||
|
|
||||||
|
ExprKind Expr::getKind() const { return storage->kind; }
|
||||||
|
|
||||||
|
Expr Expr::operator+(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::Add, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator-(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::Sub, *this, other);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Expr::operator==(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::EQ, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator!=(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::NE, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator<(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::LT, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator<=(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::LE, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator>(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::GT, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator>=(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::GE, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator&&(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::And, *this, other);
|
||||||
|
}
|
||||||
|
Expr Expr::operator||(Expr other) const {
|
||||||
|
return BinaryExpr(ExprKind::Or, *this, other);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Free functions.
|
||||||
|
llvm::SmallVector<Bindable, 8> makeBindables(unsigned n) {
|
||||||
|
llvm::SmallVector<Bindable, 8> res;
|
||||||
|
res.reserve(n);
|
||||||
|
for (auto i = 0; i < n; ++i) {
|
||||||
|
res.push_back(Bindable());
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Expr, 8> makeExprs(unsigned n) {
|
||||||
|
llvm::SmallVector<Expr, 8> res;
|
||||||
|
res.reserve(n);
|
||||||
|
for (auto i = 0; i < n; ++i) {
|
||||||
|
res.push_back(Expr());
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Expr, 8> makeExprs(ArrayRef<Bindable> bindables) {
|
||||||
|
llvm::SmallVector<Expr, 8> res;
|
||||||
|
res.reserve(bindables.size());
|
||||||
|
for (auto b : bindables) {
|
||||||
|
res.push_back(b);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr alloc(llvm::ArrayRef<Expr> sizes, Type memrefType) {
|
||||||
|
return VariadicExpr(ExprKind::Alloc, sizes, memrefType);
|
||||||
|
}
|
||||||
|
|
||||||
|
Stmt Block(ArrayRef<Stmt> stmts) {
|
||||||
|
return Stmt(StmtBlockLikeExpr(ExprKind::Block, {}), stmts);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr dealloc(Expr memref) { return UnaryExpr(ExprKind::Dealloc, memref); }
|
||||||
|
|
||||||
|
Stmt For(Expr lb, Expr ub, Expr step, ArrayRef<Stmt> stmts) {
|
||||||
|
Bindable idx;
|
||||||
|
return For(idx, lb, ub, step, stmts);
|
||||||
|
}
|
||||||
|
|
||||||
|
Stmt For(const Bindable &idx, Expr lb, Expr ub, Expr step,
|
||||||
|
ArrayRef<Stmt> stmts) {
|
||||||
|
return Stmt(idx, StmtBlockLikeExpr(ExprKind::For, {lb, ub, step}), stmts);
|
||||||
|
}
|
||||||
|
|
||||||
|
Stmt ForNest(MutableArrayRef<Bindable> indices, ArrayRef<Expr> lbs,
|
||||||
|
ArrayRef<Expr> ubs, ArrayRef<Expr> steps,
|
||||||
|
ArrayRef<Stmt> enclosedStmts) {
|
||||||
|
assert(!indices.empty());
|
||||||
|
assert(indices.size() == lbs.size());
|
||||||
|
assert(indices.size() == ubs.size());
|
||||||
|
assert(indices.size() == steps.size());
|
||||||
|
Stmt curStmt =
|
||||||
|
For(indices.back(), lbs.back(), ubs.back(), steps.back(), enclosedStmts);
|
||||||
|
for (int64_t i = indices.size() - 2; i >= 0; --i) {
|
||||||
|
Stmt nextStmt = For(indices[i], lbs[i], ubs[i], steps[i], {curStmt});
|
||||||
|
curStmt = nextStmt;
|
||||||
|
}
|
||||||
|
return curStmt;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr load(Expr m, llvm::ArrayRef<Expr> indices) {
|
||||||
|
SmallVector<Expr, 8> exprs;
|
||||||
|
exprs.push_back(m);
|
||||||
|
exprs.insert(exprs.end(), indices.begin(), indices.end());
|
||||||
|
return VariadicExpr(ExprKind::Load, exprs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr store(Expr val, Expr m, llvm::ArrayRef<Expr> indices) {
|
||||||
|
SmallVector<Expr, 8> exprs;
|
||||||
|
exprs.push_back(val);
|
||||||
|
exprs.push_back(m);
|
||||||
|
exprs.insert(exprs.end(), indices.begin(), indices.end());
|
||||||
|
return VariadicExpr(ExprKind::Store, exprs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr select(Expr cond, Expr lhs, Expr rhs) {
|
||||||
|
return TernaryExpr(ExprKind::Select, cond, lhs, rhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr vector_type_cast(Expr memrefExpr, Type memrefType) {
|
||||||
|
return VariadicExpr(ExprKind::VectorTypeCast, {memrefExpr}, {memrefType});
|
||||||
|
}
|
||||||
|
|
||||||
|
void Expr::print(raw_ostream &os) const {
|
||||||
|
if (auto unbound = this->dyn_cast<Bindable>()) {
|
||||||
|
os << "bindable";
|
||||||
|
return;
|
||||||
|
} else if (auto bin = this->dyn_cast<BinaryExpr>()) {
|
||||||
|
os << bin.getLHS();
|
||||||
|
switch (bin.getKind()) {
|
||||||
|
case ExprKind::Add:
|
||||||
|
os << " + ";
|
||||||
|
break;
|
||||||
|
case ExprKind::Sub:
|
||||||
|
os << " - ";
|
||||||
|
break;
|
||||||
|
case ExprKind::LT:
|
||||||
|
os << " < ";
|
||||||
|
break;
|
||||||
|
case ExprKind::LE:
|
||||||
|
os << " <= ";
|
||||||
|
break;
|
||||||
|
case ExprKind::GT:
|
||||||
|
os << " > ";
|
||||||
|
break;
|
||||||
|
case ExprKind::GE:
|
||||||
|
os << " >= ";
|
||||||
|
break;
|
||||||
|
default: {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << bin.getRHS();
|
||||||
|
} else if (auto ter = this->dyn_cast<TernaryExpr>()) {
|
||||||
|
switch (ter.getKind()) {
|
||||||
|
case ExprKind::Select:
|
||||||
|
os << "select(" << ter.getCond() << ", " << ter.getLHS() << ", "
|
||||||
|
<< ter.getRHS() << ")";
|
||||||
|
return;
|
||||||
|
default: {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (auto nar = this->dyn_cast<VariadicExpr>()) {
|
||||||
|
switch (nar.getKind()) {
|
||||||
|
case ExprKind::Load:
|
||||||
|
os << "load( ... )";
|
||||||
|
return;
|
||||||
|
default: {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (auto stmtLikeExpr = this->dyn_cast<StmtBlockLikeExpr>()) {
|
||||||
|
auto exprs = stmtLikeExpr.getExprs();
|
||||||
|
assert(exprs.size() == 3 && "For StmtBlockLikeExpr expected 3 exprs");
|
||||||
|
switch (stmtLikeExpr.getKind()) {
|
||||||
|
// We only print the lb, ub and step here, which are the StmtBlockLike
|
||||||
|
// part of the `for` StmtBlockLikeExpr.
|
||||||
|
case ExprKind::For:
|
||||||
|
os << exprs[0] << " to " << exprs[1] << " step " << exprs[2];
|
||||||
|
return;
|
||||||
|
default: {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
os << "unknown_kind(" << static_cast<int>(getKind()) << ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
void Expr::dump() const { this->print(llvm::errs()); }
|
||||||
|
|
||||||
|
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Expr &expr) {
|
||||||
|
expr.print(os);
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
Bindable::Bindable(ExprKind kind)
|
||||||
|
: Expr(Expr::globalAllocator()->Allocate<detail::ExprStorage>()) {
|
||||||
|
// Initialize with placement new.
|
||||||
|
new (storage) detail::ExprStorage{kind};
|
||||||
|
}
|
||||||
|
|
||||||
|
UnaryExpr::UnaryExpr(ExprKind kind, Expr expr)
|
||||||
|
: Expr(Expr::globalAllocator()->Allocate<detail::UnaryExprStorage>()) {
|
||||||
|
// Initialize with placement new.
|
||||||
|
new (storage) detail::UnaryExprStorage{kind, expr};
|
||||||
|
}
|
||||||
|
Expr UnaryExpr::getExpr() const {
|
||||||
|
return static_cast<ImplType *>(storage)->expr;
|
||||||
|
}
|
||||||
|
|
||||||
|
BinaryExpr::BinaryExpr(ExprKind kind, Expr lhs, Expr rhs)
|
||||||
|
: Expr(Expr::globalAllocator()->Allocate<detail::BinaryExprStorage>()) {
|
||||||
|
// Initialize with placement new.
|
||||||
|
new (storage) detail::BinaryExprStorage{kind, lhs, rhs};
|
||||||
|
}
|
||||||
|
Expr BinaryExpr::getLHS() const {
|
||||||
|
return static_cast<ImplType *>(storage)->lhs;
|
||||||
|
}
|
||||||
|
Expr BinaryExpr::getRHS() const {
|
||||||
|
return static_cast<ImplType *>(storage)->rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
TernaryExpr::TernaryExpr(ExprKind kind, Expr cond, Expr lhs, Expr rhs)
|
||||||
|
: Expr(Expr::globalAllocator()->Allocate<detail::TernaryExprStorage>()) {
|
||||||
|
// Initialize with placement new.
|
||||||
|
new (storage) detail::TernaryExprStorage{kind, cond, lhs, rhs};
|
||||||
|
}
|
||||||
|
Expr TernaryExpr::getCond() const {
|
||||||
|
return static_cast<ImplType *>(storage)->cond;
|
||||||
|
}
|
||||||
|
Expr TernaryExpr::getLHS() const {
|
||||||
|
return static_cast<ImplType *>(storage)->lhs;
|
||||||
|
}
|
||||||
|
Expr TernaryExpr::getRHS() const {
|
||||||
|
return static_cast<ImplType *>(storage)->rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
VariadicExpr::VariadicExpr(ExprKind kind, ArrayRef<Expr> exprs,
|
||||||
|
ArrayRef<Type> types)
|
||||||
|
: Expr(Expr::globalAllocator()->Allocate<detail::VariadicExprStorage>()) {
|
||||||
|
// Initialize with placement new.
|
||||||
|
auto exprStorage = Expr::globalAllocator()->Allocate<Expr>(exprs.size());
|
||||||
|
std::uninitialized_copy(exprs.begin(), exprs.end(), exprStorage);
|
||||||
|
auto typeStorage = Expr::globalAllocator()->Allocate<Type>(types.size());
|
||||||
|
std::uninitialized_copy(types.begin(), types.end(), typeStorage);
|
||||||
|
new (storage) detail::VariadicExprStorage{
|
||||||
|
kind, ArrayRef<Expr>(exprStorage, exprs.size()),
|
||||||
|
ArrayRef<Type>(typeStorage, types.size())};
|
||||||
|
}
|
||||||
|
ArrayRef<Expr> VariadicExpr::getExprs() const {
|
||||||
|
return static_cast<ImplType *>(storage)->exprs;
|
||||||
|
}
|
||||||
|
ArrayRef<Type> VariadicExpr::getTypes() const {
|
||||||
|
return static_cast<ImplType *>(storage)->types;
|
||||||
|
}
|
||||||
|
|
||||||
|
Stmt::Stmt(const Bindable &lhs, const Expr &rhs,
|
||||||
|
llvm::ArrayRef<Stmt> enclosedStmts) {
|
||||||
|
storage = Expr::globalAllocator()->Allocate<detail::StmtStorage>();
|
||||||
|
// Initialize with placement new.
|
||||||
|
auto enclosedStmtStorage =
|
||||||
|
Expr::globalAllocator()->Allocate<Stmt>(enclosedStmts.size());
|
||||||
|
std::uninitialized_copy(enclosedStmts.begin(), enclosedStmts.end(),
|
||||||
|
enclosedStmtStorage);
|
||||||
|
new (storage) detail::StmtStorage{
|
||||||
|
lhs, rhs, ArrayRef<Stmt>(enclosedStmtStorage, enclosedStmts.size())};
|
||||||
|
}
|
||||||
|
|
||||||
|
Stmt::Stmt(const Expr &rhs, llvm::ArrayRef<Stmt> enclosedStmts)
|
||||||
|
: Stmt(Bindable(), rhs, enclosedStmts) {}
|
||||||
|
|
||||||
|
Stmt &Stmt::operator=(const Expr &expr) {
|
||||||
|
Stmt res(Bindable(), expr, {});
|
||||||
|
std::swap(res.storage, this->storage);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
Bindable Stmt::getLHS() const { return static_cast<ImplType *>(storage)->lhs; }
|
||||||
|
|
||||||
|
Expr Stmt::getRHS() const { return static_cast<ImplType *>(storage)->rhs; }
|
||||||
|
|
||||||
|
llvm::ArrayRef<Stmt> Stmt::getEnclosedStmts() const {
|
||||||
|
return storage->enclosedStmts;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Stmt::print(raw_ostream &os, Twine indent) const {
|
||||||
|
assert(storage && "Unexpected null storage,stmt must be bound to print");
|
||||||
|
auto lhs = getLHS();
|
||||||
|
auto rhs = getRHS();
|
||||||
|
|
||||||
|
if (auto stmt = rhs.dyn_cast<StmtBlockLikeExpr>()) {
|
||||||
|
switch (stmt.getKind()) {
|
||||||
|
case ExprKind::For:
|
||||||
|
os << indent << "for(idx(" << lhs << ")=" << rhs << ") {";
|
||||||
|
os << " // @" << storage;
|
||||||
|
os << "\n";
|
||||||
|
for (const auto &s : getEnclosedStmts()) {
|
||||||
|
if (!s.getRHS().isa<StmtBlockLikeExpr>()) {
|
||||||
|
os << indent << " ";
|
||||||
|
}
|
||||||
|
s.print(os, indent + " ");
|
||||||
|
os << ";\n";
|
||||||
|
}
|
||||||
|
os << indent << "}";
|
||||||
|
return;
|
||||||
|
default: {
|
||||||
|
// TODO(ntv): print more statement cases.
|
||||||
|
os << "TODO";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
os << "lhs(" << lhs << ") = " << rhs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Stmt::dump() const { this->print(llvm::errs()); }
|
||||||
|
|
||||||
|
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Stmt &stmt) {
|
||||||
|
stmt.print(os);
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace edsc
|
||||||
|
} // namespace mlir
|
|
@ -20,7 +20,7 @@
|
||||||
// operations.
|
// operations.
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
//
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Pass.h"
|
#include "mlir/Pass.h"
|
||||||
|
|
|
@ -19,10 +19,13 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "mlir/Analysis/AffineAnalysis.h"
|
#include "mlir/Analysis/AffineAnalysis.h"
|
||||||
#include "mlir/Analysis/MLFunctionMatcher.h"
|
#include "mlir/Analysis/MLFunctionMatcher.h"
|
||||||
#include "mlir/Analysis/Utils.h"
|
#include "mlir/Analysis/Utils.h"
|
||||||
#include "mlir/Analysis/VectorAnalysis.h"
|
#include "mlir/Analysis/VectorAnalysis.h"
|
||||||
|
#include "mlir/EDSC/MLIREmitter.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -41,9 +44,9 @@
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
#include "llvm/Support/Allocator.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
///
|
///
|
||||||
/// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
|
/// Implements lowering of VectorTransferReadOp and VectorTransferWriteOp to a
|
||||||
|
@ -59,145 +62,365 @@ using namespace mlir;
|
||||||
|
|
||||||
#define DEBUG_TYPE "lower-vector-transfers"
|
#define DEBUG_TYPE "lower-vector-transfers"
|
||||||
|
|
||||||
/// Creates the Value for the sum of `a` and `b` without building a
|
/// This function emits the proper Value* at the place of insertion of b,
|
||||||
/// full-fledged AffineMap for all indices.
|
/// where each value is the proper ConstantOp or DimOp. Returns a vector with
|
||||||
|
/// these Value*. Note this function does not concern itself with hoisting of
|
||||||
|
/// constants and will produce redundant IR. Subsequent MLIR simplification
|
||||||
|
/// passes like LICM and CSE are expected to clean this up.
|
||||||
///
|
///
|
||||||
/// Prerequisites:
|
/// More specifically, a MemRefType has a shape vector in which:
|
||||||
/// `a` and `b` must be of IndexType.
|
/// - constant ranks are embedded explicitly with their value;
|
||||||
static Value *add(FuncBuilder *b, Location loc, Value *v, Value *w) {
|
/// - symbolic ranks are represented implicitly by -1 and need to be recovered
|
||||||
assert(v->getType().isa<IndexType>() && "v must be of IndexType");
|
/// with a DimOp operation.
|
||||||
assert(w->getType().isa<IndexType>() && "w must be of IndexType");
|
///
|
||||||
auto *context = b->getContext();
|
/// Example:
|
||||||
auto d0 = getAffineDimExpr(0, context);
|
/// When called on:
|
||||||
auto d1 = getAffineDimExpr(1, context);
|
///
|
||||||
auto map = AffineMap::get(2, 0, {d0 + d1}, {});
|
/// ```mlir
|
||||||
return b->create<AffineApplyOp>(loc, map, ArrayRef<mlir::Value *>{v, w})
|
/// memref<?x3x4x?x5xf32>
|
||||||
->getResult(0);
|
/// ```
|
||||||
|
///
|
||||||
|
/// This emits MLIR similar to:
|
||||||
|
///
|
||||||
|
/// ```mlir
|
||||||
|
/// %d0 = dim %0, 0 : memref<?x3x4x?x5xf32>
|
||||||
|
/// %c3 = constant 3 : index
|
||||||
|
/// %c4 = constant 4 : index
|
||||||
|
/// %d1 = dim %0, 0 : memref<?x3x4x?x5xf32>
|
||||||
|
/// %c5 = constant 5 : index
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// and returns the vector with {%d0, %c3, %c4, %d1, %c5}.
|
||||||
|
bool isDynamicSize(int size) { return size < 0; }
|
||||||
|
SmallVector<Value *, 8> getMemRefSizes(FuncBuilder *b, Location loc,
|
||||||
|
Value *memRef) {
|
||||||
|
auto memRefType = memRef->getType().template cast<MemRefType>();
|
||||||
|
SmallVector<Value *, 8> res;
|
||||||
|
res.reserve(memRefType.getShape().size());
|
||||||
|
unsigned countSymbolicShapes = 0;
|
||||||
|
for (int size : memRefType.getShape()) {
|
||||||
|
if (isDynamicSize(size)) {
|
||||||
|
res.push_back(b->create<DimOp>(loc, memRef, countSymbolicShapes++));
|
||||||
|
} else {
|
||||||
|
res.push_back(b->create<ConstantIndexOp>(loc, size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct LowerVectorTransfersState : public MLFuncGlobalLoweringState {
|
/// Helper structure to hold information about loop nest, clipped accesses to
|
||||||
// Top of the function constant zero index.
|
/// the original scalar MemRef as well as full accesses to temporary MemRef in
|
||||||
Value *zero;
|
/// local storage.
|
||||||
|
struct VectorTransferAccessInfo {
|
||||||
|
// `ivs` are bound for `For` Stmt at `For` Stmt construction time.
|
||||||
|
llvm::SmallVector<edsc::Bindable, 8> ivs;
|
||||||
|
llvm::SmallVector<edsc::Expr, 8> lowerBoundsExprs;
|
||||||
|
llvm::SmallVector<edsc::Expr, 8> upperBoundsExprs;
|
||||||
|
llvm::SmallVector<edsc::Expr, 8> stepExprs;
|
||||||
|
llvm::SmallVector<edsc::Expr, 8> clippedScalarAccessExprs;
|
||||||
|
llvm::SmallVector<edsc::Expr, 8> tmpAccessExprs;
|
||||||
};
|
};
|
||||||
} // namespace
|
|
||||||
|
|
||||||
/// Performs simple lowering into a combination of:
|
template <typename VectorTransferOpTy> class VectorTransferRewriter {
|
||||||
/// 1. local memory allocation,
|
public:
|
||||||
/// 2. vector_load/vector_store from/to local buffer
|
/// Perform the rewrite using the `emitter`.
|
||||||
/// 3. perfect loop nest over scalar loads/stores from/to remote memory.
|
VectorTransferRewriter(VectorTransferOpTy *transfer,
|
||||||
///
|
|
||||||
/// This is a simple sketch for now but does the job.
|
|
||||||
// TODO(ntv): This function has a lot of code conditioned on the template
|
|
||||||
// argument being one of the two types. Extract the common behavior into helper
|
|
||||||
// functions and detemplatizing it.
|
|
||||||
template <typename VectorTransferOpTy>
|
|
||||||
static void rewriteAsLoops(VectorTransferOpTy *transfer,
|
|
||||||
MLFuncLoweringRewriter *rewriter,
|
MLFuncLoweringRewriter *rewriter,
|
||||||
LowerVectorTransfersState *state) {
|
MLFuncGlobalLoweringState *state);
|
||||||
static_assert(
|
|
||||||
std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value ||
|
|
||||||
std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value,
|
|
||||||
"Must be called on either VectorTransferReadOp or VectorTransferWriteOp");
|
|
||||||
auto vectorType = transfer->getVectorType();
|
|
||||||
auto vectorShape = vectorType.getShape();
|
|
||||||
// tmpMemRefType is used for staging the transfer in a local scalar buffer.
|
|
||||||
auto tmpMemRefType =
|
|
||||||
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0);
|
|
||||||
// vectorMemRefType is a view of tmpMemRefType as one vector.
|
|
||||||
auto vectorMemRefType = MemRefType::get({1}, vectorType, {}, 0);
|
|
||||||
|
|
||||||
// Get the ML function builder.
|
/// Perform the rewrite using the `emitter`.
|
||||||
// We need access to the Function builder stored internally in the
|
void rewrite();
|
||||||
// MLFunctionLoweringRewriter general rewriting API does not provide
|
|
||||||
// ML-specific functions (ForInst and Block manipulation). While we could
|
|
||||||
// forward them or define a whole rewriting chain based on MLFunctionBuilder
|
|
||||||
// instead of Builer, the code for it would be duplicate boilerplate. As we
|
|
||||||
// go towards unifying ML and CFG functions, this separation will disappear.
|
|
||||||
FuncBuilder &b = *rewriter->getBuilder();
|
|
||||||
|
|
||||||
// 1. First allocate the local buffer in fast memory.
|
/// Helper class which creates clipped memref accesses to support lowering of
|
||||||
// TODO(ntv): CL memory space.
|
/// the vector_transfer operation.
|
||||||
// TODO(ntv): Allocation padding for potential bank conflicts (e.g. GPUs).
|
VectorTransferAccessInfo makeVectorTransferAccessInfo();
|
||||||
auto tmpScalarAlloc = b.create<AllocOp>(transfer->getLoc(), tmpMemRefType);
|
|
||||||
auto vecView = b.create<VectorTypeCastOp>(
|
|
||||||
transfer->getLoc(), tmpScalarAlloc->getResult(), vectorMemRefType);
|
|
||||||
|
|
||||||
// 2. Store the vector to local storage in case of a vector_transfer_write.
|
private:
|
||||||
// TODO(ntv): This vector_store operation should be further lowered in the
|
VectorTransferOpTy *transfer;
|
||||||
// case of GPUs.
|
MLFuncLoweringRewriter *rewriter;
|
||||||
if (std::is_same<VectorTransferOpTy, VectorTransferWriteOp>::value) {
|
MLFuncGlobalLoweringState *state;
|
||||||
b.create<StoreOp>(vecView->getLoc(), transfer->getVector(),
|
|
||||||
vecView->getResult(),
|
MemRefType memrefType;
|
||||||
ArrayRef<mlir::Value *>{state->zero});
|
ArrayRef<int> memrefShape;
|
||||||
|
VectorType vectorType;
|
||||||
|
ArrayRef<int> vectorShape;
|
||||||
|
AffineMap permutationMap;
|
||||||
|
|
||||||
|
/// Used for staging the transfer in a local scalar buffer.
|
||||||
|
MemRefType tmpMemRefType;
|
||||||
|
/// View of tmpMemRefType as one vector, used in vector load/store to tmp
|
||||||
|
/// buffer.
|
||||||
|
MemRefType vectorMemRefType;
|
||||||
|
|
||||||
|
// EDSC `emitter` and Bindables that are pre-bound at construction time.
|
||||||
|
// vectorSizes are bound to the actual constant sizes of vectorType.
|
||||||
|
llvm::SmallVector<edsc::Bindable, 8> vectorSizes;
|
||||||
|
// accesses are bound to transfer->getIndices()
|
||||||
|
llvm::SmallVector<edsc::Bindable, 8> accesses;
|
||||||
|
// `zero` and `one` are bound to locally scoped constants.
|
||||||
|
// `scalarMemRef` is bound to `transfer->getMemRef()`.
|
||||||
|
edsc::Bindable zero, one, scalarMemRef;
|
||||||
|
edsc::MLIREmitter emitter;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/// Consider the case:
|
||||||
|
///
|
||||||
|
/// ```mlir {.mlir}
|
||||||
|
/// // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into
|
||||||
|
/// // vector<32x256xf32> and pad with %f0 to handle the boundary case:
|
||||||
|
/// %f0 = constant 0.0f : f32
|
||||||
|
/// for %i0 = 0 to %0 {
|
||||||
|
/// for %i1 = 0 to %1 step 256 {
|
||||||
|
/// for %i2 = 0 to %2 step 32 {
|
||||||
|
/// %v = vector_transfer_read %A, %i0, %i1, %i2, %f0
|
||||||
|
/// {permutation_map: (d0, d1, d2) -> (d2, d1)} :
|
||||||
|
/// (memref<?x?x?xf32>, index, index, f32) -> vector<32x256xf32>
|
||||||
|
/// }}}
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// The following constructs the `loadAccessExpr` that supports the emission of
|
||||||
|
/// MLIR resembling:
|
||||||
|
///
|
||||||
|
/// ```mlir
|
||||||
|
/// for %d1 = 0 to 256 {
|
||||||
|
/// for %d2 = 0 to 32 {
|
||||||
|
/// %s = %A[%i0, %i1 + %d1, %i2 + %d2] : f32
|
||||||
|
/// %tmp[%d2, %d1] = %s
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Notice in particular the order of loops iterating over the vector size
|
||||||
|
/// (i.e. 256x32 instead of 32x256). This results in contiguous accesses along
|
||||||
|
/// the most minor dimension of the original scalar tensor. On many hardware
|
||||||
|
/// architectures this will result in better utilization of the underlying
|
||||||
|
/// memory subsystem (e.g. prefetchers, DMAs, #memory transactions, etc...).
|
||||||
|
///
|
||||||
|
/// This additionally performs clipping as described in
|
||||||
|
/// `VectorTransferRewriter<VectorTransferReadOp>::rewrite` by emitting:
|
||||||
|
///
|
||||||
|
/// ```mlir-dsc
|
||||||
|
/// select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one))
|
||||||
|
/// ```
|
||||||
|
template <typename VectorTransferOpTy>
|
||||||
|
VectorTransferAccessInfo
|
||||||
|
VectorTransferRewriter<VectorTransferOpTy>::makeVectorTransferAccessInfo() {
|
||||||
|
using namespace mlir::edsc;
|
||||||
|
|
||||||
|
// Create Bindable objects for ivs, they will be bound at `For` Stmt
|
||||||
|
// construction.
|
||||||
|
auto ivs = makeBindables(vectorShape.size());
|
||||||
|
|
||||||
|
// Create and bind Bindables to refer to the Value for memref sizes.
|
||||||
|
auto memRefSizes = makeBindables(memrefShape.size());
|
||||||
|
auto memrefSizeValues = getMemRefSizes(
|
||||||
|
emitter.getBuilder(), emitter.getLocation(), transfer->getMemRef());
|
||||||
|
assert(memrefSizeValues.size() == memRefSizes.size());
|
||||||
|
// Bind
|
||||||
|
emitter.bindZipRange(llvm::zip(memRefSizes, memrefSizeValues));
|
||||||
|
|
||||||
|
// Create the edsc::Expr for the clipped and transposes access expressions
|
||||||
|
// using the permutationMap. Additionally, capture the index accessing the
|
||||||
|
// most minor dimension.
|
||||||
|
int coalescingIndex = -1;
|
||||||
|
auto clippedScalarAccessExprs = makeExprs(accesses);
|
||||||
|
auto tmpAccessExprs = makeExprs(ivs);
|
||||||
|
for (auto it : llvm::enumerate(permutationMap.getResults())) {
|
||||||
|
if (auto affineExpr = it.value().template dyn_cast<AffineDimExpr>()) {
|
||||||
|
auto pos = affineExpr.getPosition();
|
||||||
|
auto i = clippedScalarAccessExprs[pos];
|
||||||
|
auto ii = ivs[it.index()];
|
||||||
|
auto N = memRefSizes[pos];
|
||||||
|
clippedScalarAccessExprs[pos] =
|
||||||
|
select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one));
|
||||||
|
if (pos == clippedScalarAccessExprs.size() - 1) {
|
||||||
|
// If a result of the permutation_map accesses the most minor dimension
|
||||||
|
// then we record it.
|
||||||
|
coalescingIndex = it.index();
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Emit the loop-nest.
|
|
||||||
// TODO(ntv): Invert the mapping and indexing contiguously in the remote
|
|
||||||
// memory.
|
|
||||||
// TODO(ntv): Handle broadcast / slice properly.
|
|
||||||
auto permutationMap = transfer->getPermutationMap();
|
|
||||||
SetVector<ForInst *> loops;
|
|
||||||
SmallVector<Value *, 8> accessIndices(transfer->getIndices());
|
|
||||||
for (auto it : llvm::enumerate(transfer->getVectorType().getShape())) {
|
|
||||||
auto composed = composeWithUnboundedMap(
|
|
||||||
getAffineDimExpr(it.index(), b.getContext()), permutationMap);
|
|
||||||
auto *forInst = b.createFor(transfer->getLoc(), 0, it.value());
|
|
||||||
loops.insert(forInst);
|
|
||||||
// Setting the insertion point to the innermost loop achieves nesting.
|
|
||||||
b.setInsertionPointToStart(loops.back()->getBody());
|
|
||||||
if (composed == getAffineConstantExpr(0, b.getContext())) {
|
|
||||||
transfer->emitWarning(
|
|
||||||
"Redundant copy can be implemented as a vector broadcast");
|
|
||||||
} else {
|
} else {
|
||||||
auto dim = composed.template cast<AffineDimExpr>();
|
// Sanity check.
|
||||||
assert(accessIndices.size() > dim.getPosition());
|
assert(it.value().template cast<AffineConstantExpr>().getValue() == 0 &&
|
||||||
accessIndices[dim.getPosition()] =
|
"Expected dim or 0 in permutationMap");
|
||||||
::add(&b, transfer->getLoc(), accessIndices[dim.getPosition()],
|
|
||||||
loops.back());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Emit memory operations within the loops.
|
// Create the proper bindables for lbs, ubs and steps. Additionally, if we
|
||||||
// TODO(ntv): SelectOp + padding value for load out-of-bounds.
|
// recorded a coalescing index, permute the loop informations.
|
||||||
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
|
auto lbs = makeBindables(ivs.size());
|
||||||
// VectorTransferReadOp.
|
auto ubs = makeExprs(vectorSizes);
|
||||||
// a. read scalar from remote;
|
auto steps = makeBindables(ivs.size());
|
||||||
// b. write scalar to local.
|
if (coalescingIndex >= 0) {
|
||||||
auto scalarLoad = b.create<LoadOp>(transfer->getLoc(),
|
std::swap(ivs[coalescingIndex], ivs.back());
|
||||||
transfer->getMemRef(), accessIndices);
|
std::swap(lbs[coalescingIndex], lbs.back());
|
||||||
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
|
std::swap(ubs[coalescingIndex], ubs.back());
|
||||||
tmpScalarAlloc->getResult(),
|
std::swap(steps[coalescingIndex], steps.back());
|
||||||
functional::map([](Value *val) { return val; }, loops));
|
}
|
||||||
} else {
|
emitter
|
||||||
// VectorTransferWriteOp.
|
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||||
// a. read scalar from local;
|
llvm::zip(lbs, SmallVector<int, 8>(ivs.size(), 0)))
|
||||||
// b. write scalar to remote.
|
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||||
auto scalarLoad = b.create<LoadOp>(
|
llvm::zip(steps, SmallVector<int, 8>(ivs.size(), 1)));
|
||||||
transfer->getLoc(), tmpScalarAlloc->getResult(),
|
|
||||||
functional::map([](Value *val) { return val; }, loops));
|
return VectorTransferAccessInfo{ivs,
|
||||||
b.create<StoreOp>(transfer->getLoc(), scalarLoad->getResult(),
|
makeExprs(lbs),
|
||||||
transfer->getMemRef(), accessIndices);
|
ubs,
|
||||||
|
makeExprs(steps),
|
||||||
|
clippedScalarAccessExprs,
|
||||||
|
tmpAccessExprs};
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Read the vector from local storage in case of a vector_transfer_read.
|
template <typename VectorTransferOpTy>
|
||||||
// TODO(ntv): This vector_load operation should be further lowered in the
|
VectorTransferRewriter<VectorTransferOpTy>::VectorTransferRewriter(
|
||||||
// case of GPUs.
|
VectorTransferOpTy *transfer, MLFuncLoweringRewriter *rewriter,
|
||||||
llvm::SmallVector<Value *, 1> newResults = {};
|
MLFuncGlobalLoweringState *state)
|
||||||
if (std::is_same<VectorTransferOpTy, VectorTransferReadOp>::value) {
|
: transfer(transfer), rewriter(rewriter), state(state),
|
||||||
b.setInsertionPoint(cast<OperationInst>(transfer->getInstruction()));
|
memrefType(transfer->getMemRefType()), memrefShape(memrefType.getShape()),
|
||||||
auto *vector = b.create<LoadOp>(transfer->getLoc(), vecView->getResult(),
|
vectorType(transfer->getVectorType()), vectorShape(vectorType.getShape()),
|
||||||
ArrayRef<Value *>{state->zero})
|
permutationMap(transfer->getPermutationMap()),
|
||||||
->getResult();
|
tmpMemRefType(
|
||||||
newResults.push_back(vector);
|
MemRefType::get(vectorShape, vectorType.getElementType(), {}, 0)),
|
||||||
|
vectorMemRefType(MemRefType::get({1}, vectorType, {}, 0)),
|
||||||
|
vectorSizes(edsc::makeBindables(vectorShape.size())),
|
||||||
|
emitter(edsc::MLIREmitter(rewriter->getBuilder(), transfer->getLoc())) {
|
||||||
|
// Bind the Bindable.
|
||||||
|
SmallVector<Value *, 8> transferIndices(transfer->getIndices());
|
||||||
|
accesses = edsc::makeBindables(transferIndices.size());
|
||||||
|
emitter.bind(scalarMemRef, transfer->getMemRef())
|
||||||
|
.template bindConstant<ConstantIndexOp>(zero, 0)
|
||||||
|
.template bindConstant<ConstantIndexOp>(one, 1)
|
||||||
|
.template bindZipRangeConstants<ConstantIndexOp>(
|
||||||
|
llvm::zip(vectorSizes, vectorShape))
|
||||||
|
.template bindZipRange(llvm::zip(accesses, transfer->getIndices()));
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Lowers VectorTransferReadOp into a combination of:
|
||||||
|
/// 1. local memory allocation;
|
||||||
|
/// 2. perfect loop nest over:
|
||||||
|
/// a. scalar load from local buffers (viewed as a scalar memref);
|
||||||
|
/// a. scalar store to original memref (with clipping).
|
||||||
|
/// 3. vector_load from local buffer (viewed as a memref<1 x vector>);
|
||||||
|
/// 4. local memory deallocation.
|
||||||
|
///
|
||||||
|
/// Lowers the data transfer part of a VectorTransferReadOp while ensuring no
|
||||||
|
/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
|
||||||
|
/// clipping. This means that a given value in memory can be read multiple
|
||||||
|
/// times and concurrently.
|
||||||
|
///
|
||||||
|
/// Important notes about clipping and "full-tiles only" abstraction:
|
||||||
|
/// =================================================================
|
||||||
|
/// When using clipping for dealing with boundary conditions, the same edge
|
||||||
|
/// value will appear multiple times (a.k.a edge padding). This is fine if the
|
||||||
|
/// subsequent vector operations are all data-parallel but **is generally
|
||||||
|
/// incorrect** in the presence of reductions or extract operations.
|
||||||
|
///
|
||||||
|
/// More generally, clipping is a scalar abstraction that is expected to work
|
||||||
|
/// fine as a baseline for CPUs and GPUs but not for vector_load and DMAs.
|
||||||
|
/// To deal with real vector_load and DMAs, a "padded allocation + view"
|
||||||
|
/// abstraction with the ability to read out-of-memref-bounds (but still within
|
||||||
|
/// the allocated region) is necessary.
|
||||||
|
///
|
||||||
|
/// Whether using scalar loops or vector_load/DMAs to perform the transfer,
|
||||||
|
/// junk values will be materialized in the vectors and generally need to be
|
||||||
|
/// filtered out and replaced by the "neutral element". This neutral element is
|
||||||
|
/// op-dependent so, in the future, we expect to create a vector filter and
|
||||||
|
/// apply it to a splatted constant vector with the proper neutral element at
|
||||||
|
/// each ssa-use. This filtering is not necessary for pure data-parallel
|
||||||
|
/// operations.
|
||||||
|
///
|
||||||
|
/// In the case of vector_store/DMAs, Read-Modify-Write will be required, which
|
||||||
|
/// also have concurrency implications. Note that by using clipped scalar stores
|
||||||
|
/// in the presence of data-parallel only operations, we generate code that
|
||||||
|
/// writes the same value multiple time on the edge locations.
|
||||||
|
///
|
||||||
|
/// TODO(ntv): implement alternatives to clipping.
|
||||||
|
/// TODO(ntv): support non-data-parallel operations.
|
||||||
|
template <> void VectorTransferRewriter<VectorTransferReadOp>::rewrite() {
|
||||||
|
using namespace mlir::edsc;
|
||||||
|
|
||||||
|
// Build the AccessInfo which contain all the information needed to build the
|
||||||
|
// perfectly nest loop nest to perform clipped reads and local writes.
|
||||||
|
auto accessInfo = makeVectorTransferAccessInfo();
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
auto &ivs = accessInfo.ivs;
|
||||||
|
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||||
|
auto &ubs = accessInfo.upperBoundsExprs;
|
||||||
|
auto &steps = accessInfo.stepExprs;
|
||||||
|
Stmt scalarValue, vectorValue, tmpAlloc, tmpDealloc, vectorView;
|
||||||
|
Stmt block = edsc::Block({
|
||||||
|
tmpAlloc = alloc(tmpMemRefType),
|
||||||
|
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||||
|
ForNest(ivs, lbs, ubs, steps, {
|
||||||
|
scalarValue = load(scalarMemRef, accessInfo.clippedScalarAccessExprs),
|
||||||
|
store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs),
|
||||||
|
}),
|
||||||
|
vectorValue = load(vectorView, zero),
|
||||||
|
tmpDealloc = dealloc(tmpAlloc.getLHS())});
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
// Emit the MLIR.
|
||||||
|
emitter.emitStmt(block);
|
||||||
|
|
||||||
|
// Finalize rewriting.
|
||||||
|
transfer->replaceAllUsesWith(emitter.getValue(vectorValue.getLHS()));
|
||||||
|
transfer->erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Free the local buffer.
|
/// Lowers VectorTransferWriteOp into a combination of:
|
||||||
b.setInsertionPoint(transfer->getInstruction());
|
/// 1. local memory allocation;
|
||||||
b.create<DeallocOp>(transfer->getLoc(), tmpScalarAlloc);
|
/// 2. vector_store to local buffer (viewed as a memref<1 x vector>);
|
||||||
|
/// 3. perfect loop nest over:
|
||||||
|
/// a. scalar load from local buffers (viewed as a scalar memref);
|
||||||
|
/// a. scalar store to original memref (with clipping).
|
||||||
|
/// 4. local memory deallocation.
|
||||||
|
///
|
||||||
|
/// More specifically, lowers the data transfer part while ensuring no
|
||||||
|
/// out-of-bounds accesses are possible. Out-of-bounds behavior is handled by
|
||||||
|
/// clipping. This means that a given value in memory can be written to multiple
|
||||||
|
/// times and concurrently.
|
||||||
|
///
|
||||||
|
/// See `Important notes about clipping and full-tiles only abstraction` in the
|
||||||
|
/// description of `readClipped` above.
|
||||||
|
///
|
||||||
|
/// TODO(ntv): implement alternatives to clipping.
|
||||||
|
/// TODO(ntv): support non-data-parallel operations.
|
||||||
|
template <> void VectorTransferRewriter<VectorTransferWriteOp>::rewrite() {
|
||||||
|
using namespace mlir::edsc;
|
||||||
|
|
||||||
// 7. It is now safe to erase the instruction.
|
// Build the AccessInfo which contain all the information needed to build the
|
||||||
rewriter->replaceOp(transfer->getInstruction(), newResults);
|
// perfectly nest loop nest to perform local reads and clipped writes.
|
||||||
|
auto accessInfo = makeVectorTransferAccessInfo();
|
||||||
|
|
||||||
|
// Bind vector value for the vector_transfer_write.
|
||||||
|
Bindable vectorValue;
|
||||||
|
emitter.bind(vectorValue, transfer->getVector());
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
auto &ivs = accessInfo.ivs;
|
||||||
|
auto &lbs = accessInfo.lowerBoundsExprs;
|
||||||
|
auto &ubs = accessInfo.upperBoundsExprs;
|
||||||
|
auto &steps = accessInfo.stepExprs;
|
||||||
|
Stmt scalarValue, tmpAlloc, tmpDealloc, vectorView;
|
||||||
|
Stmt block = edsc::Block({
|
||||||
|
tmpAlloc = alloc(tmpMemRefType),
|
||||||
|
vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
|
||||||
|
store(vectorValue, vectorView, {zero}),
|
||||||
|
ForNest(ivs, lbs, ubs, steps, {
|
||||||
|
scalarValue = load(tmpAlloc, accessInfo.tmpAccessExprs),
|
||||||
|
store(scalarValue, scalarMemRef, accessInfo.clippedScalarAccessExprs),
|
||||||
|
}),
|
||||||
|
tmpDealloc = dealloc(tmpAlloc.getLHS())});
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
// Emit the MLIR.
|
||||||
|
emitter.emitStmt(block);
|
||||||
|
|
||||||
|
// Finalize rewriting.
|
||||||
|
transfer->erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -212,18 +435,15 @@ public:
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
return matchFailure();
|
return matchFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
void rewriteOpInst(OperationInst *op,
|
void rewriteOpInst(OperationInst *op,
|
||||||
MLFuncGlobalLoweringState *funcWiseState,
|
MLFuncGlobalLoweringState *funcWiseState,
|
||||||
std::unique_ptr<PatternState> opState,
|
std::unique_ptr<PatternState> opState,
|
||||||
MLFuncLoweringRewriter *rewriter) const override {
|
MLFuncLoweringRewriter *rewriter) const override {
|
||||||
rewriteAsLoops(&*op->dyn_cast<VectorTransferOpTy>(), rewriter,
|
VectorTransferRewriter<VectorTransferOpTy>(
|
||||||
static_cast<LowerVectorTransfersState *>(funcWiseState));
|
&*op->dyn_cast<VectorTransferOpTy>(), rewriter, funcWiseState)
|
||||||
|
.rewrite();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
struct LowerVectorTransfersPass
|
struct LowerVectorTransfersPass
|
||||||
: public MLPatternLoweringPass<
|
: public MLPatternLoweringPass<
|
||||||
|
@ -232,13 +452,8 @@ struct LowerVectorTransfersPass
|
||||||
LowerVectorTransfersPass()
|
LowerVectorTransfersPass()
|
||||||
: MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {}
|
: MLPatternLoweringPass(&LowerVectorTransfersPass::passID) {}
|
||||||
|
|
||||||
std::unique_ptr<MLFuncGlobalLoweringState>
|
// Thread-safe RAII context with local scope. BumpPtrAllocator freed on exit.
|
||||||
makeFuncWiseState(Function *f) const override {
|
edsc::ScopedEDSCContext raiiContext;
|
||||||
auto state = llvm::make_unique<LowerVectorTransfersState>();
|
|
||||||
auto builder = FuncBuilder(f);
|
|
||||||
state->zero = builder.create<ConstantIndexOp>(builder.getUnknownLoc(), 0);
|
|
||||||
return state;
|
|
||||||
}
|
|
||||||
|
|
||||||
static char passID;
|
static char passID;
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,26 +1,55 @@
|
||||||
// RUN: mlir-opt %s -lower-vector-transfers | FileCheck %s
|
// RUN: mlir-opt %s -lower-vector-transfers | FileCheck %s
|
||||||
|
|
||||||
// CHECK: #[[ADD:map[0-9]+]] = (d0, d1) -> (d0 + d1)
|
// CHECK: #[[ADD:map[0-9]+]] = (d0, d1) -> (d0 + d1)
|
||||||
|
// CHECK: #[[SUB:map[0-9]+]] = (d0, d1) -> (d0 - d1)
|
||||||
|
// CHECK-LABEL: mlfunc @materialize_read(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||||
mlfunc @materialize_read(%M: index, %N: index, %O: index, %P: index) {
|
mlfunc @materialize_read(%M: index, %N: index, %O: index, %P: index) {
|
||||||
%A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
|
// CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
|
||||||
// CHECK: for %i0 = 0 to %arg0 step 3 {
|
// CHECK-NEXT: for %i0 = 0 to %arg0 step 3 {
|
||||||
// CHECK-NEXT: for %i1 = 0 to %arg1 {
|
// CHECK-NEXT: for %i1 = 0 to %arg1 {
|
||||||
// CHECK-NEXT: for %i2 = 0 to %arg2 {
|
// CHECK-NEXT: for %i2 = 0 to %arg2 {
|
||||||
// CHECK-NEXT: for %i3 = 0 to %arg3 step 5 {
|
// CHECK-NEXT: for %i3 = 0 to %arg3 step 5 {
|
||||||
// CHECK-NEXT: %1 = alloc() : memref<5x4x3xf32>
|
// CHECK-NEXT: %c0 = constant 0 : index
|
||||||
// CHECK-NEXT: %2 = vector_type_cast %1 : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
|
// CHECK-NEXT: %c1 = constant 1 : index
|
||||||
// CHECK-NEXT: for %i4 = 0 to 5 {
|
// CHECK: %1 = dim %0, 0 : memref<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %3 = affine_apply #[[ADD]](%i3, %i4)
|
// CHECK-NEXT: %2 = dim %0, 1 : memref<?x?x?x?xf32>
|
||||||
|
// CHECK-NEXT: %3 = dim %0, 2 : memref<?x?x?x?xf32>
|
||||||
|
// CHECK-NEXT: %4 = dim %0, 3 : memref<?x?x?x?xf32>
|
||||||
|
// CHECK: %5 = alloc() : memref<5x4x3xf32>
|
||||||
|
// CHECK-NEXT: %6 = vector_type_cast %5 : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
|
||||||
|
// CHECK-NEXT: for %i4 = 0 to 3 {
|
||||||
// CHECK-NEXT: for %i5 = 0 to 4 {
|
// CHECK-NEXT: for %i5 = 0 to 4 {
|
||||||
// CHECK-NEXT: for %i6 = 0 to 3 {
|
// CHECK-NEXT: for %i6 = 0 to 5 {
|
||||||
// CHECK-NEXT: %4 = affine_apply #[[ADD]](%i0, %i6)
|
// CHECK-NEXT: %7 = affine_apply #[[ADD]](%i0, %i4)
|
||||||
// CHECK-NEXT: %5 = load %0[%4, %i1, %i2, %3] : memref<?x?x?x?xf32>
|
// CHECK-NEXT: %8 = cmpi "slt", %7, %c0 : index
|
||||||
// CHECK-NEXT: store %5, %1[%i4, %i5, %i6] : memref<5x4x3xf32>
|
// CHECK-NEXT: %9 = affine_apply #[[ADD]](%i0, %i4)
|
||||||
|
// CHECK-NEXT: %10 = cmpi "slt", %9, %1 : index
|
||||||
|
// CHECK-NEXT: %11 = affine_apply #[[ADD]](%i0, %i4)
|
||||||
|
// CHECK-NEXT: %12 = affine_apply #[[SUB]](%1, %c1)
|
||||||
|
// CHECK-NEXT: %13 = select %10, %11, %12 : index
|
||||||
|
// CHECK-NEXT: %14 = select %8, %c0, %13 : index
|
||||||
|
// CHECK-NEXT: %15 = affine_apply #[[ADD]](%i3, %i6)
|
||||||
|
// CHECK-NEXT: %16 = cmpi "slt", %15, %c0 : index
|
||||||
|
// CHECK-NEXT: %17 = affine_apply #[[ADD]](%i3, %i6)
|
||||||
|
// CHECK-NEXT: %18 = cmpi "slt", %17, %4 : index
|
||||||
|
// CHECK-NEXT: %19 = affine_apply #[[ADD]](%i3, %i6)
|
||||||
|
// CHECK-NEXT: %20 = affine_apply #[[SUB]](%4, %c1)
|
||||||
|
// CHECK-NEXT: %21 = select %18, %19, %20 : index
|
||||||
|
// CHECK-NEXT: %22 = select %16, %c0, %21 : index
|
||||||
|
// CHECK-NEXT: %23 = load %0[%14, %i1, %i2, %22] : memref<?x?x?x?xf32>
|
||||||
|
// CHECK-NEXT: store %23, %5[%i6, %i5, %i4] : memref<5x4x3xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: %6 = load %2[%c0] : memref<1xvector<5x4x3xf32>>
|
// CHECK-NEXT: %24 = load %6[%c0] : memref<1xvector<5x4x3xf32>>
|
||||||
// CHECK-NEXT: dealloc %1 : memref<5x4x3xf32>
|
// CHECK-NEXT: dealloc %5 : memref<5x4x3xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT:}
|
||||||
|
%A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
|
||||||
for %i0 = 0 to %M step 3 {
|
for %i0 = 0 to %M step 3 {
|
||||||
for %i1 = 0 to %N {
|
for %i1 = 0 to %N {
|
||||||
for %i2 = 0 to %O {
|
for %i2 = 0 to %O {
|
||||||
|
@ -33,28 +62,64 @@ mlfunc @materialize_read(%M : index, %N : index, %O : index, %P : index) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL:mlfunc @materialize_write(%arg0: index, %arg1: index, %arg2: index, %arg3: index) {
|
||||||
mlfunc @materialize_write(%M: index, %N: index, %O: index, %P: index) {
|
mlfunc @materialize_write(%M: index, %N: index, %O: index, %P: index) {
|
||||||
%A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
|
// CHECK-NEXT: %0 = alloc(%arg0, %arg1, %arg2, %arg3) : memref<?x?x?x?xf32>
|
||||||
%f1 = constant splat<vector<5x4x3xf32>, 1.000000e+00> : vector<5x4x3xf32>
|
// CHECK-NEXT: %cst = constant splat<vector<5x4x3xf32>, 1.000000e+00> : vector<5x4x3xf32>
|
||||||
// CHECK: for %i0 = 0 to %arg0 step 3 {
|
// CHECK-NEXT: for %i0 = 0 to %arg0 step 3 {
|
||||||
// CHECK-NEXT: for %i1 = 0 to %arg1 step 4 {
|
// CHECK-NEXT: for %i1 = 0 to %arg1 step 4 {
|
||||||
// CHECK-NEXT: for %i2 = 0 to %arg2 {
|
// CHECK-NEXT: for %i2 = 0 to %arg2 {
|
||||||
// CHECK-NEXT: for %i3 = 0 to %arg3 step 5 {
|
// CHECK-NEXT: for %i3 = 0 to %arg3 step 5 {
|
||||||
// CHECK-NEXT: %1 = alloc() : memref<5x4x3xf32>
|
// CHECK-NEXT: %c0 = constant 0 : index
|
||||||
// CHECK-NEXT: %2 = vector_type_cast %1 : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
|
// CHECK-NEXT: %c1 = constant 1 : index
|
||||||
// CHECK-NEXT: store %cst, %2[%c0] : memref<1xvector<5x4x3xf32>>
|
// CHECK: %1 = dim %0, 0 : memref<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: for %i4 = 0 to 5 {
|
// CHECK-NEXT: %2 = dim %0, 1 : memref<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: %3 = affine_apply #[[ADD]](%i3, %i4)
|
// CHECK-NEXT: %3 = dim %0, 2 : memref<?x?x?x?xf32>
|
||||||
|
// CHECK-NEXT: %4 = dim %0, 3 : memref<?x?x?x?xf32>
|
||||||
|
// CHECK: %5 = alloc() : memref<5x4x3xf32>
|
||||||
|
// CHECK-NEXT: %6 = vector_type_cast %5 : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
|
||||||
|
// CHECK-NEXT: store %cst, %6[%c0] : memref<1xvector<5x4x3xf32>>
|
||||||
|
// CHECK-NEXT: for %i4 = 0 to 3 {
|
||||||
// CHECK-NEXT: for %i5 = 0 to 4 {
|
// CHECK-NEXT: for %i5 = 0 to 4 {
|
||||||
// CHECK-NEXT: %4 = affine_apply #[[ADD]](%i1, %i5)
|
// CHECK-NEXT: for %i6 = 0 to 5 {
|
||||||
// CHECK-NEXT: for %i6 = 0 to 3 {
|
// CHECK-NEXT: %7 = load %5[%i6, %i5, %i4] : memref<5x4x3xf32>
|
||||||
// CHECK-NEXT: %5 = affine_apply #[[ADD]](%i0, %i6)
|
// CHECK-NEXT: %8 = affine_apply #[[ADD]](%i0, %i4)
|
||||||
// CHECK-NEXT: %6 = load %1[%i4, %i5, %i6] : memref<5x4x3xf32>
|
// CHECK-NEXT: %9 = cmpi "slt", %8, %c0 : index
|
||||||
// CHECK-NEXT: store %6, %0[%5, %4, %i2, %3] : memref<?x?x?x?xf32>
|
// CHECK-NEXT: %10 = affine_apply #[[ADD]](%i0, %i4)
|
||||||
|
// CHECK-NEXT: %11 = cmpi "slt", %10, %1 : index
|
||||||
|
// CHECK-NEXT: %12 = affine_apply #[[ADD]](%i0, %i4)
|
||||||
|
// CHECK-NEXT: %13 = affine_apply #[[SUB]](%1, %c1)
|
||||||
|
// CHECK-NEXT: %14 = select %11, %12, %13 : index
|
||||||
|
// CHECK-NEXT: %15 = select %9, %c0, %14 : index
|
||||||
|
// CHECK-NEXT: %16 = affine_apply #[[ADD]](%i1, %i5)
|
||||||
|
// CHECK-NEXT: %17 = cmpi "slt", %16, %c0 : index
|
||||||
|
// CHECK-NEXT: %18 = affine_apply #[[ADD]](%i1, %i5)
|
||||||
|
// CHECK-NEXT: %19 = cmpi "slt", %18, %2 : index
|
||||||
|
// CHECK-NEXT: %20 = affine_apply #[[ADD]](%i1, %i5)
|
||||||
|
// CHECK-NEXT: %21 = affine_apply #[[SUB]](%2, %c1)
|
||||||
|
// CHECK-NEXT: %22 = select %19, %20, %21 : index
|
||||||
|
// CHECK-NEXT: %23 = select %17, %c0, %22 : index
|
||||||
|
// CHECK-NEXT: %24 = affine_apply #[[ADD]](%i3, %i6)
|
||||||
|
// CHECK-NEXT: %25 = cmpi "slt", %24, %c0 : index
|
||||||
|
// CHECK-NEXT: %26 = affine_apply #[[ADD]](%i3, %i6)
|
||||||
|
// CHECK-NEXT: %27 = cmpi "slt", %26, %4 : index
|
||||||
|
// CHECK-NEXT: %28 = affine_apply #[[ADD]](%i3, %i6)
|
||||||
|
// CHECK-NEXT: %29 = affine_apply #[[SUB]](%4, %c1)
|
||||||
|
// CHECK-NEXT: %30 = select %27, %28, %29 : index
|
||||||
|
// CHECK-NEXT: %31 = select %25, %c0, %30 : index
|
||||||
|
// CHECK-NEXT: store %7, %0[%15, %23, %i2, %31] : memref<?x?x?x?xf32>
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: }
|
// CHECK-NEXT: }
|
||||||
// CHECK-NEXT: dealloc %1 : memref<5x4x3xf32>
|
// CHECK-NEXT: dealloc %5 : memref<5x4x3xf32>
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: }
|
||||||
|
// CHECK-NEXT: return
|
||||||
|
// CHECK-NEXT:}
|
||||||
|
%A = alloc (%M, %N, %O, %P) : memref<?x?x?x?xf32, 0>
|
||||||
|
%f1 = constant splat<vector<5x4x3xf32>, 1.000000e+00> : vector<5x4x3xf32>
|
||||||
for %i0 = 0 to %M step 3 {
|
for %i0 = 0 to %M step 3 {
|
||||||
for %i1 = 0 to %N step 4 {
|
for %i1 = 0 to %N step 4 {
|
||||||
for %i2 = 0 to %O {
|
for %i2 = 0 to %O {
|
||||||
|
|
Loading…
Reference in New Issue