This CL introduces a simple set of Embedded Domain-Specific Components (EDSCs) in MLIR:

1. a `Type` system of shell classes that closely matches the MLIR type system. These types are subdivided into `Bindable` leaf expressions and non-bindable `Expr` expressions;
2. an `MLIREmitter` class whose purpose is to:
   a. maintain a map of `Bindable` leaf expressions to concrete SSAValue*;
   b. provide helper functionality to specify bindings of `Bindable` classes to SSAValue* while verifying conformable types;
   c. traverse the `Expr` and emit the MLIR.

This is used on a concrete example to implement MemRef load/store with clipping in the LowerVectorTransfer pass. More specifically, the following pseudo-C++ code:

```c++
MLFuncBuilder *b = ...;
Location location = ...;
Bindable zero, one, expr, size;
// EDSC expression
auto access = select(expr < zero, zero,
                     select(expr < size, expr, size - one));
auto ssaValue = MLIREmitter(b)
                    .bind(zero, ...)
                    .bind(one, ...)
                    .bind(expr, ...)
                    .bind(size, ...)
                    .emit(location, access);
```

is used to emit all the MLIR for a clipped MemRef access.

This simple EDSC can easily be extended to more powerful patterns and should serve as the counterpart to pattern matchers (the two could potentially be unified once we get enough experience). In the future, most of this code should be TableGen'd, but for now it has concrete, valuable uses: it makes MLIR programmable in a declarative fashion.

This CL also adds `Stmt`, proper supporting free functions, and fully rewrites VectorTransferLowering using EDSCs. The code creating the EDSCs that emit a VectorTransferReadOp as loops with clipped loads is:

```c++
Stmt block = Block({
    tmpAlloc = alloc(tmpMemRefType),
    vectorView = vector_type_cast(tmpAlloc, vectorMemRefType),
    ForNest(ivs, lbs, ubs, steps, {
        scalarValue = load(scalarMemRef, accessInfo.clippedScalarAccessExprs),
        store(scalarValue, tmpAlloc, accessInfo.tmpAccessExprs),
    }),
    vectorValue = load(vectorView, zero),
    tmpDealloc = dealloc(tmpAlloc.getLHS())});
emitter.emitStmt(block);
```

where `accessInfo.clippedScalarAccessExprs` is created with:

```c++
select(i + ii < zero, zero, select(i + ii < N, i + ii, N - one));
```

The generated MLIR resembles:

```mlir
%1 = dim %0, 0 : memref<?x?x?x?xf32>
%2 = dim %0, 1 : memref<?x?x?x?xf32>
%3 = dim %0, 2 : memref<?x?x?x?xf32>
%4 = dim %0, 3 : memref<?x?x?x?xf32>
%5 = alloc() : memref<5x4x3xf32>
%6 = vector_type_cast %5 : memref<5x4x3xf32>, memref<1xvector<5x4x3xf32>>
for %i4 = 0 to 3 {
  for %i5 = 0 to 4 {
    for %i6 = 0 to 5 {
      %7 = affine_apply #map0(%i0, %i4)
      %8 = cmpi "slt", %7, %c0 : index
      %9 = affine_apply #map0(%i0, %i4)
      %10 = cmpi "slt", %9, %1 : index
      %11 = affine_apply #map0(%i0, %i4)
      %12 = affine_apply #map1(%1, %c1)
      %13 = select %10, %11, %12 : index
      %14 = select %8, %c0, %13 : index
      %15 = affine_apply #map0(%i3, %i6)
      %16 = cmpi "slt", %15, %c0 : index
      %17 = affine_apply #map0(%i3, %i6)
      %18 = cmpi "slt", %17, %4 : index
      %19 = affine_apply #map0(%i3, %i6)
      %20 = affine_apply #map1(%4, %c1)
      %21 = select %18, %19, %20 : index
      %22 = select %16, %c0, %21 : index
      %23 = load %0[%14, %i1, %i2, %22] : memref<?x?x?x?xf32>
      store %23, %5[%i6, %i5, %i4] : memref<5x4x3xf32>
    }
  }
}
%24 = load %6[%c0] : memref<1xvector<5x4x3xf32>>
dealloc %5 : memref<5x4x3xf32>
```

In particular, notice that only 3 of the 4 access dimensions are clipped: this indeed corresponds to the number of dimensions in the super-vector.

This CL also addresses the cleanups resulting from the review of the previous CL and performs some refactoring to simplify the abstraction.
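For intuition, here is a minimal plain-C++ sketch (not part of this CL) of the scalar semantics that the nested `select` clipping expression encodes; the function name `clip` and the `int64_t` index type are assumptions made for illustration only:

```c++
#include <cstdint>

// Scalar semantics of
//   select(expr < zero, zero, select(expr < size, expr, size - one))
// i.e., clamp an access index into the valid range [0, size - 1].
int64_t clip(int64_t expr, int64_t size) {
  if (expr < 0)
    return 0;                           // below the lower bound: clamp to 0
  return expr < size ? expr : size - 1; // in bounds, or clamp to size - 1
}
```

In the generated MLIR above, `%14` and `%22` compute exactly this: `%14` clips `%i0 + %i4` against `%1`, and `%22` clips `%i3 + %i6` against `%4`.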
PiperOrigin-RevId: 227367414