From 1ca772ed951e6412ef006459b56ae9a21691a97c Mon Sep 17 00:00:00 2001
From: Christopher Bate
Date: Tue, 17 May 2022 17:54:29 -0600
Subject: [PATCH] [MLIR][GPU] Add NvGpu mma.sync path to the VectorToGPU pass

This change adds the option to lower to NvGpu dialect ops during the
VectorToGPU conversion pass. Because this transformation reuses existing
VectorToGPU logic, a separate VectorToNvGpu conversion pass is not created.
The option `use-nvgpu` is added to the VectorToGPU pass. When this option is
true, the pass will attempt to convert slices rooted at `vector.contract`
operations into `nvgpu.mma.sync` ops, and `vector.transfer_read` ops are
converted to either `nvgpu.ldmatrix` or one or more `vector.load` operations.
The specific data loaded depends on the thread id within a subgroup (warp).
These index calculations depend on the data type and shape of the MMA op
according to the downstream PTX specification. The code supporting these
details is separated into `NvGpuSupport.cpp|h`.

Differential Revision: https://reviews.llvm.org/D122940
---
 mlir/include/mlir/Conversion/Passes.td        |   9 +-
 .../mlir/Conversion/VectorToGPU/VectorToGPU.h |  15 +-
 mlir/lib/Conversion/PassDetail.h              |   4 +
 .../lib/Conversion/VectorToGPU/CMakeLists.txt |   1 +
 .../Conversion/VectorToGPU/NvGpuSupport.cpp   | 327 ++++++++++++++
 .../lib/Conversion/VectorToGPU/NvGpuSupport.h | 100 +++++
 .../Conversion/VectorToGPU/VectorToGPU.cpp    | 404 +++++++++++++++++-
 .../vector-to-mma-ops-mma-sync.mlir           | 349 +++++++++++++++
 8 files changed, 1181 insertions(+), 28 deletions(-)
 create mode 100644 mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
 create mode 100644 mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
 create mode 100644 mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 41e7b29f15d0..6d9863e72348 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -851,8 +851,13 @@ def ConvertVectorToGPU : Pass<"convert-vector-to-gpu"> {
          "dialect";
   let constructor = "mlir::createConvertVectorToGPUPass()";
   let dependentDialects = [
-    "memref::MemRefDialect",
-    "gpu::GPUDialect"
+    "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect",
+    "vector::VectorDialect", "nvgpu::NVGPUDialect"
+  ];
+
+  let options = [
+    Option<"useNvGpu", "use-nvgpu", "bool", /*default=*/"false",
+           "convert to NvGPU ops instead of GPU dialect ops">
   ];
 }

diff --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
index 266fa0eac4c4..1ba5b3f90d9a 100644
--- a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
+++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
@@ -17,16 +17,25 @@ class Pass;
 class RewritePatternSet;
 
 /// Patterns to transform vector ops into a canonical form to convert to MMA
-/// matrix operations.
-void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns);
+/// matrix operations. If `useNvGpu` is true, the populated patterns will
+/// prepare for conversion to `nvgpu` mma operations rather than the `gpu`
+/// dialect WMMA operations.
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+                                        bool useNvGpu = false);
 
 /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
 /// convert slice of operations that can be legally converted to MMA operations.
 /// The rest of the vector operations are left untouched.
 void convertVectorToMMAOps(Operation *rootOp);
 
+/// Convert vector ops nested under `rootOp` to vector and GPU operations
+/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
+/// of operations that can be legally lowered on this path while the rest of
+/// the vector operations are left untouched.
+LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+
 /// Convert from vector to GPU ops.
-std::unique_ptr createConvertVectorToGPUPass();
+std::unique_ptr createConvertVectorToGPUPass(bool useNvGpu = false);
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index e05004061dd4..530e156024fd 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -55,6 +55,10 @@ namespace LLVM {
 class LLVMDialect;
 } // namespace LLVM
 
+namespace nvgpu {
+class NVGPUDialect;
+}
+
 namespace NVVM {
 class NVVMDialect;
 } // namespace NVVM
diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
index 06758c5fe126..778f2c42eebe 100644
--- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRVectorToGPU
   VectorToGPU.cpp
+  NvGpuSupport.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
new file mode 100644
index 000000000000..a2820c3e88f8
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
@@ -0,0 +1,327 @@
+//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to NvGPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NvGpuSupport.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir {
+namespace nvgpu {
+namespace {
+
+/// There are always 4 threads per [128|256|512] bit row.
+constexpr int64_t kThreadsPerRow = 4;
+
+constexpr int64_t kNumRowsPerTile = 8;
+
+bool isAccumulatorOrResult(MatMulOperandRole operandType) {
+  return operandType == MatMulOperandRole::C;
+}
+
+/// Returns the number of registers which compose a matrix fragment held by a
+/// single thread.
+int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
+  int64_t lineSize = inferTileWidthInBits(type);
+  auto shape = type.vectorType.getShape();
+  return (shape[0] / kNumRowsPerTile) *
+         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
+         lineSize;
+}
+
+/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
+/// operand shape.
+std::array getTileShape(ArrayRef operandShape,
+                        Type elementType, int64_t lineSizeBits) {
+  // For each 8x128bit square, a thread is responsible for one 32bit register.
+ return {operandShape[0] / kNumRowsPerTile, + (operandShape[1] * elementType.getIntOrFloatBitWidth()) / + lineSizeBits}; +} + +} // namespace + +FailureOr getWarpMatrixInfo(Operation *op) { + WarpMatrixInfo info; + + // Determine the vector type. + if (vector::TransferWriteOp writeOp = dyn_cast(op)) { + info.vectorType = writeOp.getVectorType(); + } else if (isa(op)) { + info.vectorType = op->getResult(0).getType().cast(); + } else { + return op->emitError() + << "unhandled operation type in nvgpu.mma.sync conversion path"; + } + + // Determine the operand role. We assume it is an accumulator/result unless it + // is directly consumed by a `vector.contract` op. + info.operandRole = MatMulOperandRole::C; + for (Operation *user : op->getUsers()) { + auto contract = dyn_cast(user); + if (!contract) + continue; + if (contract.getLhs() == op->getResult(0)) { + info.operandRole = MatMulOperandRole::A; + break; + } + if (contract.getRhs() == op->getResult(0)) { + info.operandRole = MatMulOperandRole::B; + break; + } + } + return info; +} + +int64_t inferTileWidthInBits(const WarpMatrixInfo &type) { + bool isAcc = isAccumulatorOrResult(type.operandRole); + Type elType = type.vectorType.getElementType(); + if (isAcc && elType.getIntOrFloatBitWidth() == 32) { + return 256; + } + if (elType.getIntOrFloatBitWidth() == 64) { + return isAcc ? 512 : 256; + } + return 128; +} + +FailureOr +getMmaSyncRegisterType(const WarpMatrixInfo &type) { + MLIRContext *ctx = type.vectorType.getContext(); + const bool isAccum = isAccumulatorOrResult(type.operandRole); + + Type elType = type.vectorType.getElementType(); + if (elType.isF16()) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + + // f64 operand + Type f64Ty = Float64Type::get(ctx); + if (elType.isF64()) { + return isAccum + ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128, + inferNumRegistersPerMatrixFragment(type)} + : FragmentElementInfo{f64Ty, 1, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + + // int8 operand + if (elType.isInteger(8)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + // Integer 32bit acc operands + if (elType.isInteger(32)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + + // Floating point 32bit operands + if (elType.isF32()) { + Type f32Ty = Float32Type::get(ctx); + return isAccum + ? 
FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64, + inferNumRegistersPerMatrixFragment(type)} + : FragmentElementInfo{f32Ty, 1, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + return failure(); +} + +static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, + Type elementType, + ArrayRef operandShape, + bool isAccumulator, + int64_t elementsPerRegister, + AffineExpr logicalValueId) { + const int64_t elementsPerLine = + lineSize / elementType.getIntOrFloatBitWidth(); + const std::array num8x128bTiles = + getTileShape(operandShape, elementType, lineSize); + AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister); + return AffineMap::get( + 2, 0, + {(registerIdx % num8x128bTiles[0]) * 8, + (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine}, + elementType.getContext()); +} + +FailureOr +getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, + const WarpMatrixInfo &fragmentType) { + Type elementType = fragmentType.vectorType.getElementType(); + ArrayRef operandShape = fragmentType.vectorType.getShape(); + FailureOr regInfo = + getMmaSyncRegisterType(fragmentType); + if (failed(regInfo)) + return failure(); + + const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth(); + const int64_t elementsPerRegister = + regInfo->registerWidthBits / elementBitWidth; + const int64_t lineSize = inferTileWidthInBits(fragmentType); + + AffineExpr laneId, logicalValueIdDim; + bindDims(builder.getContext(), laneId, logicalValueIdDim); + + // Determine what register logicalValueId corresponds to. Use that as a + // linear index into the coordinate mapping `index -> (tile row, tile col)`. + AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap( + lineSize, elementType, operandShape, + isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister, + logicalValueIdDim); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(2, 0, dimExprs, builder.getContext()); + }; + + auto tileRow = registerIndexToTileCoord.getResult(0); + auto tileCol = registerIndexToTileCoord.getResult(1); + return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow), + tileCol + (laneId % kThreadsPerRow) * elementsPerRegister + + (logicalValueIdDim % elementsPerRegister)}); +} + +FailureOr getLdMatrixParams(const WarpMatrixInfo &type, + bool transpose) { + LdMatrixParams params; + Type elType = type.vectorType.getElementType(); + params.fragmentType = type.vectorType; + if (type.operandRole == MatMulOperandRole::A || + type.operandRole == MatMulOperandRole::C) { + params.targetLayout = NVVM::MMALayout::row; + } else { + params.targetLayout = NVVM::MMALayout::col; + } + ArrayRef shape = type.vectorType.getShape(); + params.contiguousDimType = + transpose ? IteratorType::Parallel : IteratorType::Reduction; + + if (params.targetLayout == NVVM::MMALayout::row) { + params.numTiles = (shape[0] / kNumRowsPerTile) * + ((shape[1] * elType.getIntOrFloatBitWidth()) / 128); + } else { + params.numTiles = (shape[1] / kNumRowsPerTile) * + ((shape[0] * elType.getIntOrFloatBitWidth()) / 128); + } + + if (params.numTiles == 0) + return failure(); + + return params; +} + +FailureOr +getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, + const LdMatrixParams ¶ms) { + // One thread per 128b row. 
+ const int64_t kNumThreadsPerTile = kNumRowsPerTile; + const int bitsPerElement = static_cast( + params.fragmentType.getElementType().getIntOrFloatBitWidth()); + const int kElementsPer128b = (128 / bitsPerElement); + ArrayRef operandShape = params.fragmentType.getShape(); + AffineExpr d0 = getAffineDimExpr(0, builder.getContext()); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(1, 0, dimExprs, builder.getContext()); + }; + + // This case corresponds to row-major A|C or col-major B operands. + if (params.contiguousDimType == IteratorType::Reduction) { + AffineExpr row = d0 % (operandShape[0]); + AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b); + return makeMap({row, col}); + } + + // This case Corresponds to col-major A|C or row-major B operands. The + // operandShape given is already pre-transposed (e.g. 8x16 = KxN). + if (params.contiguousDimType == IteratorType::Parallel) { + const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128; + // Threads are assigned in groups of 8 first across columns, then to + // rows. This is transpose of what `ldmatrix` expects, but when + // `ldmatrix` gets the `.trans` qualifier, final the effect will be to + // transpose just the blocks. + auto groupIdx = d0.floorDiv(kNumThreadsPerTile); + auto tileCol = (groupIdx % num8x128bCols); + auto tileRow = groupIdx.floorDiv(num8x128bCols); + return makeMap({tileCol * kElementsPer128b, + tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)}); + } + return failure(); +} + +LogicalResult +PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value res = op.getAcc(); + + // Set up the parallel/reduction structure in right form. + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr m; + AffineExpr n; + AffineExpr k; + bindDims(rewriter.getContext(), m, n, k); + static constexpr std::array perm = {1, 0}; + auto iteratorTypes = op.getIteratorTypes().getValue(); + SmallVector maps = op.getIndexingMaps(); + if (iteratorTypes.size() != 3) + return failure(); + if (!(isParallelIterator(iteratorTypes[0]) && + isParallelIterator(iteratorTypes[1]) && + isReductionIterator(iteratorTypes[2]))) + return failure(); + + // The canonical form is "TNT" = A row-major, B col-major, C row-major. 
+  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+  if (maps == canonicalForm) {
+    return failure();
+  }
+  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+    rhs = rewriter.create(loc, rhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+    lhs = rewriter.create(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    rhs = rewriter.create(loc, rhs, perm);
+    lhs = rewriter.create(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = rewriter.create(loc, rhs, perm);
+    lhs = rewriter.create(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = rewriter.create(loc, rhs, perm);
+  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+    std::swap(lhs, rhs);
+    lhs = rewriter.create(loc, lhs, perm);
+  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+    std::swap(lhs, rhs);
+  } else {
+    return failure();
+  }
+  rewriter.replaceOpWithNewOp(
+      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+      op.getIteratorTypes());
+  return success();
+}
+
+} // namespace nvgpu
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
new file mode 100644
index 000000000000..9902faa835a6
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
@@ -0,0 +1,100 @@
+//===- NvGpuSupport.h - MLIR Vector to GPU lowering support ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides utilities to assist in the lowering of Vector operations
+// to GPU dialect MMA operations.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+namespace nvgpu {
+
+enum class MatMulOperandRole : int32_t { A = 0, B, C };
+
+/// Collects information about a warp-level matrix operand represented by a
+/// VectorType.
+struct WarpMatrixInfo {
+  VectorType vectorType;
+  MatMulOperandRole operandRole;
+};
+
+/// Given an op that operates on a VectorType representing a warp-level matrix
+/// operand, the function returns a struct containing relevant type information.
+FailureOr getWarpMatrixInfo(Operation *op);
+
+/// Returns the number of bits in a single tile row. It is either 128, 256, or
+/// 512 bits depending on the data type and whether the operand is an
+/// accumulator/result operand.
+int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
+
+/// Specifies information about the registers which compose a matrix fragment
+/// according to the PTX documentation.
+struct FragmentElementInfo {
+  Type registerLLVMType;
+  int64_t elementsPerRegister;
+  int64_t registerWidthBits;
+  int64_t numRegistersPerFragment;
+};
+
+/// Returns a FragmentElementInfo struct describing the register types for the
+/// given matrix fragment type.
+FailureOr +getMmaSyncRegisterType(const WarpMatrixInfo &type); + +/// Returns an AffineMap which maps a two dimensions representing (laneId, +/// logicalValueId) and returns two results representing offsets within a +/// matrix operand. The offsets point to the values the thread is responsible +/// for (AKA the matrix fragment values) during a warp-collective matrix +/// operation. For a visual reference of this LaneId -> (row, col) mapping, +/// please see NVIDIA's PTX documentation: +/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma +FailureOr +getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, + const WarpMatrixInfo &fragmentType); + +struct LdMatrixParams { + VectorType fragmentType; + bool isAccum; + int64_t numTiles; + IteratorType contiguousDimType; + NVVM::MMALayout targetLayout; +}; + +FailureOr getLdMatrixParams(const WarpMatrixInfo &type, + bool transpose); +/// Returns an AffineMap which maps a single dimension representing the laneId +/// to two results representing offsets within the matrix operand that should +/// be the pointer locations a thread should pass to the ldmatrix instruction. +FailureOr +getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, + const LdMatrixParams ¶ms); + +// Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted +// to MMA matmul. +struct PrepareContractToGPUMMASync + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; +}; + +} // namespace nvgpu +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 9ed1c3483c11..a6e122c38031 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -12,6 +12,7 @@ #include +#include "NvGpuSupport.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "../PassDetail.h" @@ -19,6 +20,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/NVGPUDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -27,11 +29,39 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; +/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an +/// AffineMap representing offsets to apply to indices, the function fills +/// `indices` with the original indices plus the offsets. The offsets are +/// applied by taking into account the permutation map of the transfer op. If +/// the `offsetMap` has dimension placeholders, those should be provided in +/// `dimValues`. 
+template +static void getXferIndices(OpBuilder &b, TransferOpType xferOp, + AffineMap offsetMap, ArrayRef dimValues, + SmallVector &indices) { + indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); + Location loc = xferOp.getLoc(); + unsigned offsetsIdx = 0; + for (auto expr : xferOp.getPermutationMap().getResults()) { + if (auto dim = expr.template dyn_cast()) { + Value prevIdx = indices[dim.getPosition()]; + SmallVector dims(dimValues.begin(), dimValues.end()); + dims.push_back(prevIdx); + AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims()); + indices[dim.getPosition()] = makeComposedAffineApply( + b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims); + continue; + } + } +} + // Return true if the contract op can be convert to MMA matmul. -static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { +static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, + bool useNvGpu) { if (llvm::size(contract.getMasks()) != 0) return false; @@ -47,7 +77,10 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { // The contract needs to represent a matmul to be able to convert to // MMAMatrix matmul. - if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + if (!useNvGpu && + contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + return false; + if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}})) return false; return true; @@ -61,7 +94,7 @@ getMemrefConstantHorizontalStride(ShapedType type) { if (!memrefType) return false; // If the memref is 0 or 1D the horizontal stride is 0. - if(memrefType.getRank() < 2) + if (memrefType.getRank() < 2) return 0; int64_t offset = 0; SmallVector strides; @@ -75,7 +108,8 @@ getMemrefConstantHorizontalStride(ShapedType type) { } // Return true if the transfer op can be converted to a MMA matrix load. -static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { +static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp, + bool useNvGpu) { if (readOp.getMask() || readOp.hasOutOfBoundsDim() || readOp.getVectorType().getRank() != 2) return false; @@ -87,9 +121,14 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { AffineExpr zero = b.getAffineConstantExpr(0); auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, readOp.getContext()); - // TODO: Support transpose once it is added to GPU dialect ops. - // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). - return !(!map.isMinorIdentity() && map != broadcastInnerDim); + + if (!useNvGpu) { + // TODO: Support transpose once it is added to GPU dialect ops. + // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). + return map.isMinorIdentity() || map == broadcastInnerDim; + } + + return true; } // Return true if the transfer op can be converted to a MMA matrix store. 
@@ -147,15 +186,15 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) { return convertElementwiseOpToMMA(op).hasValue(); } -static bool supportsMMaMatrixType(Operation *op) { +static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { if (isa(op)) return true; if (auto transferRead = dyn_cast(op)) - return transferReadSupportsMMAMatrixType(transferRead); + return transferReadSupportsMMAMatrixType(transferRead, useNvGpu); if (auto transferWrite = dyn_cast(op)) return transferWriteSupportsMMAMatrixType(transferWrite); if (auto contract = dyn_cast(op)) - return contractSupportsMMAMatrixType(contract); + return contractSupportsMMAMatrixType(contract, useNvGpu); if (auto constant = dyn_cast(op)) return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) @@ -203,7 +242,8 @@ static SetVector getSliceContract(Operation *op, // Analyze slice of operations based on convert op to figure out if the whole // slice can be converted to MMA operations. -static SetVector getOpToConvert(mlir::Operation *op) { +static SetVector getOpToConvert(mlir::Operation *op, + bool useNvGpu) { auto hasVectorDest = [](Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type t) { return t.isa(); }); @@ -221,8 +261,9 @@ static SetVector getOpToConvert(mlir::Operation *op) { // If any instruction cannot use MMA matrix type drop the whole // chain. MMA matrix are stored in an opaque type so they cannot be used // by all operations. - if (llvm::any_of(dependentOps, - [](Operation *op) { return !supportsMMaMatrixType(op); })) + if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { + return !supportsMMaMatrixType(op, useNvGpu); + })) return; opToConvert.insert(dependentOps.begin(), dependentOps.end()); }); @@ -351,7 +392,7 @@ static const char *inferFragType(OpTy op) { static void convertTransferReadOp(vector::TransferReadOp op, llvm::DenseMap &valueMapping) { assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); - assert(transferReadSupportsMMAMatrixType(op)); + assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); AffineMap map = op.getPermutationMap(); @@ -386,6 +427,250 @@ static void convertTransferWriteOp(vector::TransferWriteOp op, op.erase(); } +/// Returns the vector type which represents a matrix fragment. +static VectorType +getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { + SmallVector shape{regInfo.numRegistersPerFragment, + regInfo.elementsPerRegister}; + Type elType = regInfo.registerLLVMType; + if (auto vecType = elType.dyn_cast()) + elType = vecType.getElementType(); + return VectorType::get(shape, elType); +} + +/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 
+static LogicalResult +convertConstantOpMmaSync(arith::ConstantOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + auto dense = op.getValue().dyn_cast(); + if (!dense) + return failure(); + Value result = b.create( + op.getLoc(), vectorType, + DenseElementsAttr::get(vectorType, dense.getSplatValue())); + valueMapping[op.getResult()] = result; + return success(); +} + +static LogicalResult +creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder, + llvm::DenseMap &valueMapping) { + Location loc = op->getLoc(); + + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) + return failure(); + + FailureOr params = nvgpu::getLdMatrixParams( + *warpMatrixInfo, + /*transpose=*/!op.getPermutationMap().isMinorIdentity()); + if (failed(params)) { + return op->emitError() + << "failed to convert vector.transfer_read to ldmatrix; this op " + "likely " + "should not be converted to a nvgpu.ldmatrix call."; + } + + // Adjust the load offset. + auto laneId = builder.create(loc); + FailureOr offsets = + nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params); + if (failed(offsets)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + + SmallVector indices; + getXferIndices(builder, op, *offsets, {laneId}, + indices); + nvgpu::LdMatrixOp newOp = builder.create( + loc, vectorType, op.getSource(), indices, + !op.getPermutationMap().isMinorIdentity(), params->numTiles); + valueMapping[op] = newOp->getResult(0); + return success(); +} + +static LogicalResult +createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder, + llvm::DenseMap &valueMapping) { + Location loc = op.getLoc(); + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) { + op->emitError() << "Failed to deduce register fragment type during " + "conversion to distributed non-ldmatrix compatible load"; + return failure(); + } + + NVVM::MMALayout targetLayout = + warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B + ? NVVM::MMALayout::col + : NVVM::MMALayout::row; + + Value laneId = builder.create(loc); + SmallVector elements; + + // This is the individual element type. + Type loadedElType = regInfo->registerLLVMType; + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + + Value fill = builder.create( + op.getLoc(), vectorType.getElementType(), + builder.getZeroAttr(vectorType.getElementType())); + Value result = builder.create(op.getLoc(), fill, vectorType); + + bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); + + // Vectorized loads. 
+ if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) { + if (!loadedElType.isa()) { + loadedElType = VectorType::get({1}, loadedElType); + } + + for (int i = 0; i < vectorType.getShape()[0]; i++) { + FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( + op.getLoc(), builder, *warpMatrixInfo); + if (failed(coords)) + return failure(); + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister)); + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + + Value el = builder.create(loc, loadedElType, + op.getSource(), newIndices); + result = builder.create(loc, el, result, + builder.getI64ArrayAttr(i)); + } + } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) { + if (auto vecType = loadedElType.dyn_cast()) { + loadedElType = vecType.getElementType(); + } + // Load each element individually. + for (int i = 0; i < vectorType.getShape()[0]; i++) { + for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; + innerIdx++) { + + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); + FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( + op.getLoc(), builder, *warpMatrixInfo); + if (failed(coords)) + return failure(); + + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + Value el = builder.create(op.getLoc(), loadedElType, + op.getSource(), newIndices); + result = builder.create( + op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx})); + } + } + } else { + return failure(); + } + + valueMapping[op.getResult()] = result; + return success(); +} + +/// Converts a `vector.transfer_read` operation directly to either a +/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be +/// used when converting to `nvgpu.mma.sync` operations. +static LogicalResult +convertTransferReadToLoads(vector::TransferReadOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + + bool isLdMatrixCompatible = + op.getSource().getType().cast().getMemorySpaceAsInt() == 3 && + nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; + + VectorType vecTy = op.getVectorType(); + int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth(); + + // When we are transposing the B operand, ldmatrix will only work if we have + // at least 8 rows to read and the width to read for the transpose is 128 + // bits. 
+ if (!op.getPermutationMap().isMinorIdentity() && + (vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * bitWidth < 128)) + isLdMatrixCompatible = false; + + if (!isLdMatrixCompatible) + return createNonLdMatrixLoads(op, b, valueMapping); + + return creatLdMatrixCompatibleLoads(op, b, valueMapping); +} + +static LogicalResult +convertTransferWriteToStores(vector::TransferWriteOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Location loc = op->getLoc(); + Value matrix = valueMapping.find(op.getVector())->second; + + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + Value laneId = b.create(loc); + + for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { + Value logicalValueId = b.create( + loc, b.getIndexType(), + b.getIndexAttr(i * regInfo->elementsPerRegister)); + FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( + op.getLoc(), b, *warpMatrixInfo); + if (failed(coords)) + return failure(); + + Value el = b.create(loc, matrix, ArrayRef{i}); + SmallVector newIndices; + getXferIndices( + b, op, *coords, {laneId, logicalValueId}, newIndices); + b.create(loc, el, op.getSource(), newIndices); + } + op->erase(); + return success(); +} + static void convertContractOp(vector::ContractionOp op, llvm::DenseMap &valueMapping) { OpBuilder b(op); @@ -397,6 +682,22 @@ static void convertContractOp(vector::ContractionOp op, valueMapping[op.getResult()] = matmul; } +static LogicalResult +convertContractOpToMmaSync(vector::ContractionOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Value opA = valueMapping.find(op.getLhs())->second; + Value opB = valueMapping.find(op.getRhs())->second; + Value opC = valueMapping.find(op.getAcc())->second; + int64_t m = op.getLhs().getType().cast().getShape()[0]; + int64_t n = op.getRhs().getType().cast().getShape()[0]; + int64_t k = op.getLhs().getType().cast().getShape()[1]; + Value matmul = b.create( + op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k})); + valueMapping[op.getResult()] = matmul; + return success(); +} + /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 
static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap &valueMapping) { @@ -509,13 +810,20 @@ static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType, valueMapping[op->getResult(0)] = newOp; } -void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); +void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, + bool useNvGpu) { + if (!useNvGpu) { + patterns.add( + patterns.getContext()); + return; + } + patterns + .add( + patterns.getContext()); } void mlir::convertVectorToMMAOps(Operation *rootOp) { - SetVector ops = getOpToConvert(rootOp); + SetVector ops = getOpToConvert(rootOp, /*useNvGpu=*/false); llvm::DenseMap valueMapping; for (Operation *op : ops) { if (auto transferRead = dyn_cast(op)) { @@ -538,21 +846,71 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) { } } +LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) { + SetVector ops = getOpToConvert(rootOp, /*useNvGpu=*/true); + llvm::DenseMap valueMapping; + for (Operation *op : ops) { + if (llvm::TypeSwitch(op) + .Case([&](vector::TransferReadOp transferReadOp) { + return convertTransferReadToLoads(transferReadOp, valueMapping); + }) + .Case([&](vector::TransferWriteOp transferWriteOp) { + return convertTransferWriteToStores(transferWriteOp, + valueMapping); + }) + .Case([&](vector::ContractionOp contractionOp) { + return convertContractOpToMmaSync(contractionOp, valueMapping); + }) + .Case([&](scf::ForOp forOp) { + convertForOp(forOp, valueMapping); + return success(); + }) + .Case([&](scf::YieldOp yieldOp) { + convertYieldOp(yieldOp, valueMapping); + return success(); + }) + .Case([&](arith::ConstantOp constOp) { + return convertConstantOpMmaSync(constOp, valueMapping); + }) + .Default([&](Operation *op) { + op->emitError() << "unhandled vector to mma type: " << *op; + return failure(); + }) + .failed()) { + op->emitError() << "Failed to convert op " << *op; + return failure(); + } + } + return success(); +} + namespace { struct ConvertVectorToGPUPass : public ConvertVectorToGPUBase { + + explicit ConvertVectorToGPUPass(bool useNvGpu_) { + useNvGpu.setValue(useNvGpu_); + } + void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populatePrepareVectorToMMAPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); - convertVectorToMMAOps(getOperation()); + if (useNvGpu.getValue()) { + if (failed(convertVectorToNVVMCompatibleMMASync(getOperation()))) + return signalPassFailure(); + } + + (void)convertVectorToMMAOps(getOperation()); } }; } // namespace -std::unique_ptr mlir::createConvertVectorToGPUPass() { - return std::make_unique(); +std::unique_ptr mlir::createConvertVectorToGPUPass(bool useNvGpu) { + return std::make_unique(useNvGpu); } diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir new file mode 100644 index 000000000000..be8d08be06ce --- /dev/null +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir @@ -0,0 +1,349 @@ +// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})" | FileCheck %s + +//######################################################### +// INT8 row-row-row 
+//######################################################### + +// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)> + +// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)> +// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)> +// CHECK-DAG: [[$rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 40)> +// CHECK-DAG: [[$rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)> +// CHECK-DAG: [[$rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)> +// CHECK-DAG: [[$rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)> +// CHECK-DAG: [[$rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)> +// CHECK-DAG: [[$rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)> +// CHECK-DAG: [[$rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)> + +// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)> +// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)> +// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)> + + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @m16n8k32_int8_row_row_row +func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) { + %cst_0 = arith.constant dense<0> : vector<32x8xi8> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c17 = arith.constant 17 : index + %c39 = arith.constant 39 : index + %c40 = arith.constant 40 : index + %c49 = arith.constant 49 : index + %c50 = arith.constant 50 : index + %cst = arith.constant 0 : i8 + %cst0 = arith.constant 0 : i32 + + // Verify that the operand A is distributed to loads correctly. + + // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8> + + // Verify that the operand B is distributed to loads correctly. It's elements + // must be loaded in a non-vectorized manner to do the transpose. 
+ + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB1_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB2_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB3_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB4_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB5_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB6_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB7_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-NOT: memref.load %arg1 + + // Verify that the operand C is distributed to loads correctly. 
+ // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK-NOT: vector.load %arg2{{.*}} + + %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8> + %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8> + %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32> + // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32> + + // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32> + return +} + +// ----- + +//######################################################### +// f64 row-row-row +//######################################################### +// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)> +// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)> + +// CHECK-DAG: [[$rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)> +// CHECK-DAG: [[$colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)> + +// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)> +// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40) + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @m8n8k4_f64_row_row_row +func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) { + %cst_0 = arith.constant dense<0.0> : vector<4x8xf64> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c17 = arith.constant 17 : index + %c39 = arith.constant 39 : index + %c40 = arith.constant 40 : index + %c49 = arith.constant 49 : index + %c50 = arith.constant 50 : index + %cst = arith.constant 0.0 : f64 + %cst0 = arith.constant 0.0 : f64 + + // Verify that the operand A is distributed to loads correctly. + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]] + // CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64> + + // Verify that the operand B is distributed to loads correctly. 
It's elements + // must be loaded in a non-vectorized manner to do the transpose. + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowb0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colb0_map]] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64> + + %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64> + %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64> + %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64> + // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64> + vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64> + return +} + +// ----- + +//######################################################### +// FP16 row-row-row +//######################################################### + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> + +// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)> +// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)> + +// CHECK-LABEL: func @m16n8k16_fp16_row_row_row +func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true} + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16> + %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + 
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3> + return +} + +// ----- + +// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> +// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)> +// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)> + +#map0 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @batch_m16n8k16_fp16_row_row_row +func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16> + // CHECK: [[C0:%.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16> + %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16> + %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]] + // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16> + %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3> + return +} + +// ----- + +//######################################################### +// FP16 row-col-row +//######################################################### + +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> + +// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)> +// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)> + +// CHECK-LABEL: func @m16n8k16_fp16_row_col_row +func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : 
f16 + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32 + // CHECK-SAME: transpose = false + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32 + // CHECK-SAME: transpose = false + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32 + // CHECK-SAME: transpose = false + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16> + %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3> + return +} + +// ----- + +//######################################################### +// TF32 (multiplicand) F32 (accumulator) row-row-row +//######################################################### + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)> + +// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)> +// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)> + +// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)> +// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)> +// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)> + +// CHECK-LABEL: func @m16n8k4_tf32_f32_row_row_row +func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f32 + + // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} + + // b and c are not loaded by ldmatrix in this test. 
+ // CHECK-NOT: nvgpu.ldmatrix + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]] + // CHECK: [[b_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3> + // CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32> + + // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]]) + // CHECK-SAME: mmaShape = [16, 8, 4] + // CHECK-SAME: -> vector<2x2xf32> + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x4xf32> + %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x4xf32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32> + + // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32> + // CHECK: affine.apply [[$rowC_map]] + // CHECK: affine.apply [[$colC_map]] + // CHECK: vector.store + // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32> + // CHECK: affine.apply [[$rowC8_map]] + // CHECK: affine.apply [[$colC_map]] + // CHECK: vector.store + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32> + return +}
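
Editor's note, for readers trying the patch out: the RUN line at the top of the test file shows how to drive the new path from mlir-opt. The same pipeline can be assembled from C++. The sketch below is illustrative only and not part of the patch; the function name `runVectorToMmaSync` and the surrounding setup are assumptions, while `createConvertVectorToGPUPass(bool)` is the entry point this patch adds.

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Equivalent to:
//   -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})"
static mlir::LogicalResult runVectorToMmaSync(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  // Schedule the VectorToGPU pass on each function, with the new option on.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertVectorToGPUPass(/*useNvGpu=*/true));
  return pm.run(module);
}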
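A pipeline that does not want to run the full pass can also reuse the two entry points added to VectorToGPU.h directly. The sketch below mirrors what ConvertVectorToGPUPass::runOnOperation does when `use-nvgpu` is set; the wrapper name `lowerVectorToMmaSync` and the `funcOp` argument are assumptions for illustration.

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerVectorToMmaSync(mlir::Operation *funcOp) {
  mlir::RewritePatternSet patterns(funcOp->getContext());
  // Rewrite vector.contract into the (m, k) x (n, k) -> (m, n) form that
  // nvgpu.mma.sync expects (PrepareContractToGPUMMASync).
  mlir::populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/true);
  if (mlir::failed(
          mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    return mlir::failure();
  // Convert the prepared slices to nvgpu.ldmatrix / vector.load loads,
  // nvgpu.mma.sync ops, and per-lane vector.store write-backs.
  return mlir::convertVectorToNVVMCompatibleMMASync(funcOp);
}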