diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 41e7b29f15d0..6d9863e72348 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -851,8 +851,13 @@ def ConvertVectorToGPU : Pass<"convert-vector-to-gpu"> { "dialect"; let constructor = "mlir::createConvertVectorToGPUPass()"; let dependentDialects = [ - "memref::MemRefDialect", - "gpu::GPUDialect" + "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect", + "vector::VectorDialect", "nvgpu::NVGPUDialect" + ]; + + let options = [ + Option<"useNvGpu", "use-nvgpu", "bool", /*default=*/"false", + "convert to NvGPU ops instead of GPU dialect ops"> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h index 266fa0eac4c4..1ba5b3f90d9a 100644 --- a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h +++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h @@ -17,16 +17,25 @@ class Pass; class RewritePatternSet; /// Patterns to transform vector ops into a canonical form to convert to MMA -/// matrix operations. -void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns); +/// matrix operations. If `useNvGpu` is true, then the patterns will populated +/// will prepare for conversion to `nvgpu` mma operations rather than the `gpu` +/// dialect WMMA operations. +void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, + bool useNvGpu = false); /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will /// convert slice of operations that can be legally converted to MMA operations. /// The rest of the vector operations are left untouched. void convertVectorToMMAOps(Operation *rootOp); +/// Convert vector ops ops nested under `rootOp` to vector and GPU operaitons +/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice +/// of operations that can be legally lowered on this path while the rest of +/// the vector operations are left untouched. +LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp); + /// Convert from vector to GPU ops. -std::unique_ptr createConvertVectorToGPUPass(); +std::unique_ptr createConvertVectorToGPUPass(bool useNvGpu = false); } // namespace mlir diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h index e05004061dd4..530e156024fd 100644 --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -55,6 +55,10 @@ namespace LLVM { class LLVMDialect; } // namespace LLVM +namespace nvgpu { +class NVGPUDialect; +} + namespace NVVM { class NVVMDialect; } // namespace NVVM diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt index 06758c5fe126..778f2c42eebe 100644 --- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(MLIRVectorToGPU VectorToGPU.cpp + NvGpuSupport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp new file mode 100644 index 000000000000..a2820c3e88f8 --- /dev/null +++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp @@ -0,0 +1,327 @@ +//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utilities to assist in the lowering of Vector operations +// to NvGPU dialect MMA operations. +// +//===----------------------------------------------------------------------===// + +#include "NvGpuSupport.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/NVGPU/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +namespace mlir { +namespace nvgpu { +namespace { + +/// There are always 4 threads per [128|256|512] bit row. +constexpr int64_t kThreadsPerRow = 4; + +constexpr int64_t kNumRowsPerTile = 8; + +bool isAccumulatorOrResult(MatMulOperandRole operandType) { + return operandType == MatMulOperandRole::C; +} + +/// Returns the number of registers which compose a matrix fragment held by a +/// single thread. +int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) { + int64_t lineSize = inferTileWidthInBits(type); + auto shape = type.vectorType.getShape(); + return (shape[0] / kNumRowsPerTile) * + (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) / + lineSize; +} + +/// Returns the number of 8 x [128|256|512] bit tiles that compose the given +/// operand shape. +std::array getTileShape(ArrayRef operandShape, + Type elementType, int64_t lineSizeBits) { + // For each 8x128bit square, a thread is responsible for one 32bit register. + return {operandShape[0] / kNumRowsPerTile, + (operandShape[1] * elementType.getIntOrFloatBitWidth()) / + lineSizeBits}; +} + +} // namespace + +FailureOr getWarpMatrixInfo(Operation *op) { + WarpMatrixInfo info; + + // Determine the vector type. + if (vector::TransferWriteOp writeOp = dyn_cast(op)) { + info.vectorType = writeOp.getVectorType(); + } else if (isa(op)) { + info.vectorType = op->getResult(0).getType().cast(); + } else { + return op->emitError() + << "unhandled operation type in nvgpu.mma.sync conversion path"; + } + + // Determine the operand role. We assume it is an accumulator/result unless it + // is directly consumed by a `vector.contract` op. + info.operandRole = MatMulOperandRole::C; + for (Operation *user : op->getUsers()) { + auto contract = dyn_cast(user); + if (!contract) + continue; + if (contract.getLhs() == op->getResult(0)) { + info.operandRole = MatMulOperandRole::A; + break; + } + if (contract.getRhs() == op->getResult(0)) { + info.operandRole = MatMulOperandRole::B; + break; + } + } + return info; +} + +int64_t inferTileWidthInBits(const WarpMatrixInfo &type) { + bool isAcc = isAccumulatorOrResult(type.operandRole); + Type elType = type.vectorType.getElementType(); + if (isAcc && elType.getIntOrFloatBitWidth() == 32) { + return 256; + } + if (elType.getIntOrFloatBitWidth() == 64) { + return isAcc ? 512 : 256; + } + return 128; +} + +FailureOr +getMmaSyncRegisterType(const WarpMatrixInfo &type) { + MLIRContext *ctx = type.vectorType.getContext(); + const bool isAccum = isAccumulatorOrResult(type.operandRole); + + Type elType = type.vectorType.getElementType(); + if (elType.isF16()) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + + // f64 operand + Type f64Ty = Float64Type::get(ctx); + if (elType.isF64()) { + return isAccum + ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128, + inferNumRegistersPerMatrixFragment(type)} + : FragmentElementInfo{f64Ty, 1, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + + // int8 operand + if (elType.isInteger(8)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + // Integer 32bit acc operands + if (elType.isInteger(32)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + + // Floating point 32bit operands + if (elType.isF32()) { + Type f32Ty = Float32Type::get(ctx); + return isAccum + ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64, + inferNumRegistersPerMatrixFragment(type)} + : FragmentElementInfo{f32Ty, 1, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + return failure(); +} + +static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, + Type elementType, + ArrayRef operandShape, + bool isAccumulator, + int64_t elementsPerRegister, + AffineExpr logicalValueId) { + const int64_t elementsPerLine = + lineSize / elementType.getIntOrFloatBitWidth(); + const std::array num8x128bTiles = + getTileShape(operandShape, elementType, lineSize); + AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister); + return AffineMap::get( + 2, 0, + {(registerIdx % num8x128bTiles[0]) * 8, + (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine}, + elementType.getContext()); +} + +FailureOr +getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, + const WarpMatrixInfo &fragmentType) { + Type elementType = fragmentType.vectorType.getElementType(); + ArrayRef operandShape = fragmentType.vectorType.getShape(); + FailureOr regInfo = + getMmaSyncRegisterType(fragmentType); + if (failed(regInfo)) + return failure(); + + const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth(); + const int64_t elementsPerRegister = + regInfo->registerWidthBits / elementBitWidth; + const int64_t lineSize = inferTileWidthInBits(fragmentType); + + AffineExpr laneId, logicalValueIdDim; + bindDims(builder.getContext(), laneId, logicalValueIdDim); + + // Determine what register logicalValueId corresponds to. Use that as a + // linear index into the coordinate mapping `index -> (tile row, tile col)`. + AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap( + lineSize, elementType, operandShape, + isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister, + logicalValueIdDim); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(2, 0, dimExprs, builder.getContext()); + }; + + auto tileRow = registerIndexToTileCoord.getResult(0); + auto tileCol = registerIndexToTileCoord.getResult(1); + return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow), + tileCol + (laneId % kThreadsPerRow) * elementsPerRegister + + (logicalValueIdDim % elementsPerRegister)}); +} + +FailureOr getLdMatrixParams(const WarpMatrixInfo &type, + bool transpose) { + LdMatrixParams params; + Type elType = type.vectorType.getElementType(); + params.fragmentType = type.vectorType; + if (type.operandRole == MatMulOperandRole::A || + type.operandRole == MatMulOperandRole::C) { + params.targetLayout = NVVM::MMALayout::row; + } else { + params.targetLayout = NVVM::MMALayout::col; + } + ArrayRef shape = type.vectorType.getShape(); + params.contiguousDimType = + transpose ? IteratorType::Parallel : IteratorType::Reduction; + + if (params.targetLayout == NVVM::MMALayout::row) { + params.numTiles = (shape[0] / kNumRowsPerTile) * + ((shape[1] * elType.getIntOrFloatBitWidth()) / 128); + } else { + params.numTiles = (shape[1] / kNumRowsPerTile) * + ((shape[0] * elType.getIntOrFloatBitWidth()) / 128); + } + + if (params.numTiles == 0) + return failure(); + + return params; +} + +FailureOr +getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, + const LdMatrixParams ¶ms) { + // One thread per 128b row. + const int64_t kNumThreadsPerTile = kNumRowsPerTile; + const int bitsPerElement = static_cast( + params.fragmentType.getElementType().getIntOrFloatBitWidth()); + const int kElementsPer128b = (128 / bitsPerElement); + ArrayRef operandShape = params.fragmentType.getShape(); + AffineExpr d0 = getAffineDimExpr(0, builder.getContext()); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(1, 0, dimExprs, builder.getContext()); + }; + + // This case corresponds to row-major A|C or col-major B operands. + if (params.contiguousDimType == IteratorType::Reduction) { + AffineExpr row = d0 % (operandShape[0]); + AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b); + return makeMap({row, col}); + } + + // This case Corresponds to col-major A|C or row-major B operands. The + // operandShape given is already pre-transposed (e.g. 8x16 = KxN). + if (params.contiguousDimType == IteratorType::Parallel) { + const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128; + // Threads are assigned in groups of 8 first across columns, then to + // rows. This is transpose of what `ldmatrix` expects, but when + // `ldmatrix` gets the `.trans` qualifier, final the effect will be to + // transpose just the blocks. + auto groupIdx = d0.floorDiv(kNumThreadsPerTile); + auto tileCol = (groupIdx % num8x128bCols); + auto tileRow = groupIdx.floorDiv(num8x128bCols); + return makeMap({tileCol * kElementsPer128b, + tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)}); + } + return failure(); +} + +LogicalResult +PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value res = op.getAcc(); + + // Set up the parallel/reduction structure in right form. + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr m; + AffineExpr n; + AffineExpr k; + bindDims(rewriter.getContext(), m, n, k); + static constexpr std::array perm = {1, 0}; + auto iteratorTypes = op.getIteratorTypes().getValue(); + SmallVector maps = op.getIndexingMaps(); + if (iteratorTypes.size() != 3) + return failure(); + if (!(isParallelIterator(iteratorTypes[0]) && + isParallelIterator(iteratorTypes[1]) && + isReductionIterator(iteratorTypes[2]))) + return failure(); + + // The canonical form is "TNT" = A row-major, B col-major, C row-major. + const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}}); + if (maps == canonicalForm) { + return failure(); + } + if (maps == infer({{m, k}, {k, n}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { + std::swap(rhs, lhs); + rhs = rewriter.create(loc, rhs, perm); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { + std::swap(rhs, lhs); + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { + std::swap(lhs, rhs); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { + std::swap(lhs, rhs); + } else { + return failure(); + } + rewriter.replaceOpWithNewOp( + op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm), + op.getIteratorTypes()); + return success(); +} + +} // namespace nvgpu +} // namespace mlir \ No newline at end of file diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h new file mode 100644 index 000000000000..9902faa835a6 --- /dev/null +++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h @@ -0,0 +1,100 @@ +//===- NvvmMMASupport.h - MLIR Vector to GPU lowering support --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utilities to assist in the lowering of Vector operations +// to GPU dialect MMA operations. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H +#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace nvgpu { + +enum class MatMulOperandRole : int32_t { A = 0, B, C }; + +/// Collects information about a warp-level matrix operand represented by a +/// VectorType. +struct WarpMatrixInfo { + VectorType vectorType; + MatMulOperandRole operandRole; +}; + +/// Given an op that operates on a VectorType representing a warp-level matrix +/// operand, the function returns a struct containing relevant type information. +FailureOr getWarpMatrixInfo(Operation *op); + +/// Returns the number of bits in a single tile row. It is either 128, 256, or +/// 512 bits depending on the data type and` whether the operand is an +/// accumulator/result operand +int64_t inferTileWidthInBits(const WarpMatrixInfo &type); + +/// Specifies information about the registers which compose a matrix fragment +/// according to the PTX documentation. +struct FragmentElementInfo { + Type registerLLVMType; + int64_t elementsPerRegister; + int64_t registerWidthBits; + int64_t numRegistersPerFragment; +}; + +/// Returns a FragmentElementInfo struct describing the register types for the +/// given matrix fragment type. +FailureOr +getMmaSyncRegisterType(const WarpMatrixInfo &type); + +/// Returns an AffineMap which maps a two dimensions representing (laneId, +/// logicalValueId) and returns two results representing offsets within a +/// matrix operand. The offsets point to the values the thread is responsible +/// for (AKA the matrix fragment values) during a warp-collective matrix +/// operation. For a visual reference of this LaneId -> (row, col) mapping, +/// please see NVIDIA's PTX documentation: +/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma +FailureOr +getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, + const WarpMatrixInfo &fragmentType); + +struct LdMatrixParams { + VectorType fragmentType; + bool isAccum; + int64_t numTiles; + IteratorType contiguousDimType; + NVVM::MMALayout targetLayout; +}; + +FailureOr getLdMatrixParams(const WarpMatrixInfo &type, + bool transpose); +/// Returns an AffineMap which maps a single dimension representing the laneId +/// to two results representing offsets within the matrix operand that should +/// be the pointer locations a thread should pass to the ldmatrix instruction. +FailureOr +getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, + const LdMatrixParams ¶ms); + +// Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted +// to MMA matmul. +struct PrepareContractToGPUMMASync + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; +}; + +} // namespace nvgpu +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 9ed1c3483c11..a6e122c38031 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -12,6 +12,7 @@ #include +#include "NvGpuSupport.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "../PassDetail.h" @@ -19,6 +20,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/NVGPUDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -27,11 +29,39 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; +/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an +/// AffineMap representing offsets to apply to indices, the function fills +/// `indices` with the original indices plus the offsets. The offsets are +/// applied by taking into account the permutation map of the transfer op. If +/// the `offsetMap` has dimension placeholders, those should be provided in +/// `dimValues`. +template +static void getXferIndices(OpBuilder &b, TransferOpType xferOp, + AffineMap offsetMap, ArrayRef dimValues, + SmallVector &indices) { + indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); + Location loc = xferOp.getLoc(); + unsigned offsetsIdx = 0; + for (auto expr : xferOp.getPermutationMap().getResults()) { + if (auto dim = expr.template dyn_cast()) { + Value prevIdx = indices[dim.getPosition()]; + SmallVector dims(dimValues.begin(), dimValues.end()); + dims.push_back(prevIdx); + AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims()); + indices[dim.getPosition()] = makeComposedAffineApply( + b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims); + continue; + } + } +} + // Return true if the contract op can be convert to MMA matmul. -static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { +static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, + bool useNvGpu) { if (llvm::size(contract.getMasks()) != 0) return false; @@ -47,7 +77,10 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { // The contract needs to represent a matmul to be able to convert to // MMAMatrix matmul. - if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + if (!useNvGpu && + contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + return false; + if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}})) return false; return true; @@ -61,7 +94,7 @@ getMemrefConstantHorizontalStride(ShapedType type) { if (!memrefType) return false; // If the memref is 0 or 1D the horizontal stride is 0. - if(memrefType.getRank() < 2) + if (memrefType.getRank() < 2) return 0; int64_t offset = 0; SmallVector strides; @@ -75,7 +108,8 @@ getMemrefConstantHorizontalStride(ShapedType type) { } // Return true if the transfer op can be converted to a MMA matrix load. -static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { +static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp, + bool useNvGpu) { if (readOp.getMask() || readOp.hasOutOfBoundsDim() || readOp.getVectorType().getRank() != 2) return false; @@ -87,9 +121,14 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { AffineExpr zero = b.getAffineConstantExpr(0); auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, readOp.getContext()); - // TODO: Support transpose once it is added to GPU dialect ops. - // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). - return !(!map.isMinorIdentity() && map != broadcastInnerDim); + + if (!useNvGpu) { + // TODO: Support transpose once it is added to GPU dialect ops. + // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). + return map.isMinorIdentity() || map == broadcastInnerDim; + } + + return true; } // Return true if the transfer op can be converted to a MMA matrix store. @@ -147,15 +186,15 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) { return convertElementwiseOpToMMA(op).hasValue(); } -static bool supportsMMaMatrixType(Operation *op) { +static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { if (isa(op)) return true; if (auto transferRead = dyn_cast(op)) - return transferReadSupportsMMAMatrixType(transferRead); + return transferReadSupportsMMAMatrixType(transferRead, useNvGpu); if (auto transferWrite = dyn_cast(op)) return transferWriteSupportsMMAMatrixType(transferWrite); if (auto contract = dyn_cast(op)) - return contractSupportsMMAMatrixType(contract); + return contractSupportsMMAMatrixType(contract, useNvGpu); if (auto constant = dyn_cast(op)) return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) @@ -203,7 +242,8 @@ static SetVector getSliceContract(Operation *op, // Analyze slice of operations based on convert op to figure out if the whole // slice can be converted to MMA operations. -static SetVector getOpToConvert(mlir::Operation *op) { +static SetVector getOpToConvert(mlir::Operation *op, + bool useNvGpu) { auto hasVectorDest = [](Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type t) { return t.isa(); }); @@ -221,8 +261,9 @@ static SetVector getOpToConvert(mlir::Operation *op) { // If any instruction cannot use MMA matrix type drop the whole // chain. MMA matrix are stored in an opaque type so they cannot be used // by all operations. - if (llvm::any_of(dependentOps, - [](Operation *op) { return !supportsMMaMatrixType(op); })) + if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { + return !supportsMMaMatrixType(op, useNvGpu); + })) return; opToConvert.insert(dependentOps.begin(), dependentOps.end()); }); @@ -351,7 +392,7 @@ static const char *inferFragType(OpTy op) { static void convertTransferReadOp(vector::TransferReadOp op, llvm::DenseMap &valueMapping) { assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); - assert(transferReadSupportsMMAMatrixType(op)); + assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); AffineMap map = op.getPermutationMap(); @@ -386,6 +427,250 @@ static void convertTransferWriteOp(vector::TransferWriteOp op, op.erase(); } +/// Returns the vector type which represents a matrix fragment. +static VectorType +getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) { + SmallVector shape{regInfo.numRegistersPerFragment, + regInfo.elementsPerRegister}; + Type elType = regInfo.registerLLVMType; + if (auto vecType = elType.dyn_cast()) + elType = vecType.getElementType(); + return VectorType::get(shape, elType); +} + +/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. +static LogicalResult +convertConstantOpMmaSync(arith::ConstantOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + auto dense = op.getValue().dyn_cast(); + if (!dense) + return failure(); + Value result = b.create( + op.getLoc(), vectorType, + DenseElementsAttr::get(vectorType, dense.getSplatValue())); + valueMapping[op.getResult()] = result; + return success(); +} + +static LogicalResult +creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder, + llvm::DenseMap &valueMapping) { + Location loc = op->getLoc(); + + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) + return failure(); + + FailureOr params = nvgpu::getLdMatrixParams( + *warpMatrixInfo, + /*transpose=*/!op.getPermutationMap().isMinorIdentity()); + if (failed(params)) { + return op->emitError() + << "failed to convert vector.transfer_read to ldmatrix; this op " + "likely " + "should not be converted to a nvgpu.ldmatrix call."; + } + + // Adjust the load offset. + auto laneId = builder.create(loc); + FailureOr offsets = + nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *params); + if (failed(offsets)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + + SmallVector indices; + getXferIndices(builder, op, *offsets, {laneId}, + indices); + nvgpu::LdMatrixOp newOp = builder.create( + loc, vectorType, op.getSource(), indices, + !op.getPermutationMap().isMinorIdentity(), params->numTiles); + valueMapping[op] = newOp->getResult(0); + return success(); +} + +static LogicalResult +createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder, + llvm::DenseMap &valueMapping) { + Location loc = op.getLoc(); + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) { + op->emitError() << "Failed to deduce register fragment type during " + "conversion to distributed non-ldmatrix compatible load"; + return failure(); + } + + NVVM::MMALayout targetLayout = + warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B + ? NVVM::MMALayout::col + : NVVM::MMALayout::row; + + Value laneId = builder.create(loc); + SmallVector elements; + + // This is the individual element type. + Type loadedElType = regInfo->registerLLVMType; + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + + Value fill = builder.create( + op.getLoc(), vectorType.getElementType(), + builder.getZeroAttr(vectorType.getElementType())); + Value result = builder.create(op.getLoc(), fill, vectorType); + + bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); + + // Vectorized loads. + if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) { + if (!loadedElType.isa()) { + loadedElType = VectorType::get({1}, loadedElType); + } + + for (int i = 0; i < vectorType.getShape()[0]; i++) { + FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( + op.getLoc(), builder, *warpMatrixInfo); + if (failed(coords)) + return failure(); + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister)); + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + + Value el = builder.create(loc, loadedElType, + op.getSource(), newIndices); + result = builder.create(loc, el, result, + builder.getI64ArrayAttr(i)); + } + } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) { + if (auto vecType = loadedElType.dyn_cast()) { + loadedElType = vecType.getElementType(); + } + // Load each element individually. + for (int i = 0; i < vectorType.getShape()[0]; i++) { + for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; + innerIdx++) { + + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); + FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( + op.getLoc(), builder, *warpMatrixInfo); + if (failed(coords)) + return failure(); + + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + Value el = builder.create(op.getLoc(), loadedElType, + op.getSource(), newIndices); + result = builder.create( + op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx})); + } + } + } else { + return failure(); + } + + valueMapping[op.getResult()] = result; + return success(); +} + +/// Converts a `vector.transfer_read` operation directly to either a +/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be +/// used when converting to `nvgpu.mma.sync` operations. +static LogicalResult +convertTransferReadToLoads(vector::TransferReadOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + + bool isLdMatrixCompatible = + op.getSource().getType().cast().getMemorySpaceAsInt() == 3 && + nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128; + + VectorType vecTy = op.getVectorType(); + int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth(); + + // When we are transposing the B operand, ldmatrix will only work if we have + // at least 8 rows to read and the width to read for the transpose is 128 + // bits. + if (!op.getPermutationMap().isMinorIdentity() && + (vecTy.getDimSize(1) < 8 || vecTy.getDimSize(0) * bitWidth < 128)) + isLdMatrixCompatible = false; + + if (!isLdMatrixCompatible) + return createNonLdMatrixLoads(op, b, valueMapping); + + return creatLdMatrixCompatibleLoads(op, b, valueMapping); +} + +static LogicalResult +convertTransferWriteToStores(vector::TransferWriteOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Location loc = op->getLoc(); + Value matrix = valueMapping.find(op.getVector())->second; + + FailureOr warpMatrixInfo = + nvgpu::getWarpMatrixInfo(op); + if (failed(warpMatrixInfo)) + return failure(); + FailureOr regInfo = + nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); + if (failed(regInfo)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); + Value laneId = b.create(loc); + + for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { + Value logicalValueId = b.create( + loc, b.getIndexType(), + b.getIndexAttr(i * regInfo->elementsPerRegister)); + FailureOr coords = nvgpu::getLaneIdAndValueIdToOperandCoord( + op.getLoc(), b, *warpMatrixInfo); + if (failed(coords)) + return failure(); + + Value el = b.create(loc, matrix, ArrayRef{i}); + SmallVector newIndices; + getXferIndices( + b, op, *coords, {laneId, logicalValueId}, newIndices); + b.create(loc, el, op.getSource(), newIndices); + } + op->erase(); + return success(); +} + static void convertContractOp(vector::ContractionOp op, llvm::DenseMap &valueMapping) { OpBuilder b(op); @@ -397,6 +682,22 @@ static void convertContractOp(vector::ContractionOp op, valueMapping[op.getResult()] = matmul; } +static LogicalResult +convertContractOpToMmaSync(vector::ContractionOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Value opA = valueMapping.find(op.getLhs())->second; + Value opB = valueMapping.find(op.getRhs())->second; + Value opC = valueMapping.find(op.getAcc())->second; + int64_t m = op.getLhs().getType().cast().getShape()[0]; + int64_t n = op.getRhs().getType().cast().getShape()[0]; + int64_t k = op.getLhs().getType().cast().getShape()[1]; + Value matmul = b.create( + op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k})); + valueMapping[op.getResult()] = matmul; + return success(); +} + /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap &valueMapping) { @@ -509,13 +810,20 @@ static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType, valueMapping[op->getResult(0)] = newOp; } -void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); +void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, + bool useNvGpu) { + if (!useNvGpu) { + patterns.add( + patterns.getContext()); + return; + } + patterns + .add( + patterns.getContext()); } void mlir::convertVectorToMMAOps(Operation *rootOp) { - SetVector ops = getOpToConvert(rootOp); + SetVector ops = getOpToConvert(rootOp, /*useNvGpu=*/false); llvm::DenseMap valueMapping; for (Operation *op : ops) { if (auto transferRead = dyn_cast(op)) { @@ -538,21 +846,71 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) { } } +LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) { + SetVector ops = getOpToConvert(rootOp, /*useNvGpu=*/true); + llvm::DenseMap valueMapping; + for (Operation *op : ops) { + if (llvm::TypeSwitch(op) + .Case([&](vector::TransferReadOp transferReadOp) { + return convertTransferReadToLoads(transferReadOp, valueMapping); + }) + .Case([&](vector::TransferWriteOp transferWriteOp) { + return convertTransferWriteToStores(transferWriteOp, + valueMapping); + }) + .Case([&](vector::ContractionOp contractionOp) { + return convertContractOpToMmaSync(contractionOp, valueMapping); + }) + .Case([&](scf::ForOp forOp) { + convertForOp(forOp, valueMapping); + return success(); + }) + .Case([&](scf::YieldOp yieldOp) { + convertYieldOp(yieldOp, valueMapping); + return success(); + }) + .Case([&](arith::ConstantOp constOp) { + return convertConstantOpMmaSync(constOp, valueMapping); + }) + .Default([&](Operation *op) { + op->emitError() << "unhandled vector to mma type: " << *op; + return failure(); + }) + .failed()) { + op->emitError() << "Failed to convert op " << *op; + return failure(); + } + } + return success(); +} + namespace { struct ConvertVectorToGPUPass : public ConvertVectorToGPUBase { + + explicit ConvertVectorToGPUPass(bool useNvGpu_) { + useNvGpu.setValue(useNvGpu_); + } + void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populatePrepareVectorToMMAPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); - convertVectorToMMAOps(getOperation()); + if (useNvGpu.getValue()) { + if (failed(convertVectorToNVVMCompatibleMMASync(getOperation()))) + return signalPassFailure(); + } + + (void)convertVectorToMMAOps(getOperation()); } }; } // namespace -std::unique_ptr mlir::createConvertVectorToGPUPass() { - return std::make_unique(); +std::unique_ptr mlir::createConvertVectorToGPUPass(bool useNvGpu) { + return std::make_unique(useNvGpu); } diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir new file mode 100644 index 000000000000..be8d08be06ce --- /dev/null +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir @@ -0,0 +1,349 @@ +// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})" | FileCheck %s + +//######################################################### +// INT8 row-row-row +//######################################################### + +// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)> + +// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)> +// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)> +// CHECK-DAG: [[$rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 40)> +// CHECK-DAG: [[$rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)> +// CHECK-DAG: [[$rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)> +// CHECK-DAG: [[$rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)> +// CHECK-DAG: [[$rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)> +// CHECK-DAG: [[$rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)> +// CHECK-DAG: [[$rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)> + +// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)> +// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)> +// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)> + + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @m16n8k32_int8_row_row_row +func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) { + %cst_0 = arith.constant dense<0> : vector<32x8xi8> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c17 = arith.constant 17 : index + %c39 = arith.constant 39 : index + %c40 = arith.constant 40 : index + %c49 = arith.constant 49 : index + %c50 = arith.constant 50 : index + %cst = arith.constant 0 : i8 + %cst0 = arith.constant 0 : i32 + + // Verify that the operand A is distributed to loads correctly. + + // CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8> + + // Verify that the operand B is distributed to loads correctly. It's elements + // must be loaded in a non-vectorized manner to do the transpose. + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB1_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB2_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB3_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB4_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB5_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB6_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB7_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3> + // CHECK-NOT: memref.load %arg1 + + // Verify that the operand C is distributed to loads correctly. + // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK-NOT: vector.load %arg2{{.*}} + + %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8> + %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8> + %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32> + // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32> + + // CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32> + return +} + +// ----- + +//######################################################### +// f64 row-row-row +//######################################################### +// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)> +// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)> + +// CHECK-DAG: [[$rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)> +// CHECK-DAG: [[$colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)> + +// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)> +// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40) + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @m8n8k4_f64_row_row_row +func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) { + %cst_0 = arith.constant dense<0.0> : vector<4x8xf64> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c17 = arith.constant 17 : index + %c39 = arith.constant 39 : index + %c40 = arith.constant 40 : index + %c49 = arith.constant 49 : index + %c50 = arith.constant 50 : index + %cst = arith.constant 0.0 : f64 + %cst0 = arith.constant 0.0 : f64 + + // Verify that the operand A is distributed to loads correctly. + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]] + // CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64> + + // Verify that the operand B is distributed to loads correctly. It's elements + // must be loaded in a non-vectorized manner to do the transpose. + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowb0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colb0_map]] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64> + + %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64> + %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64> + %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64> + // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64> + vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64> + return +} + +// ----- + +//######################################################### +// FP16 row-row-row +//######################################################### + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> + +// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)> +// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)> + +// CHECK-LABEL: func @m16n8k16_fp16_row_row_row +func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true} + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16> + %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3> + return +} + +// ----- + +// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> +// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)> +// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)> + +#map0 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @batch_m16n8k16_fp16_row_row_row +func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16> + // CHECK: [[C0:%.+]] = arith.constant 0 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16> + %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16> + %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]] + // CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16> + %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3> + return +} + +// ----- + +//######################################################### +// FP16 row-col-row +//######################################################### + +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> + +// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)> +// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)> + +// CHECK-LABEL: func @m16n8k16_fp16_row_col_row +func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32 + // CHECK-SAME: transpose = false + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32 + // CHECK-SAME: transpose = false + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32 + // CHECK-SAME: transpose = false + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16> + %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3> + return +} + +// ----- + +//######################################################### +// TF32 (multiplicand) F32 (accumulator) row-row-row +//######################################################### + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)> + +// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)> +// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)> + +// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)> +// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)> +// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)> + +// CHECK-LABEL: func @m16n8k4_tf32_f32_row_row_row +func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f32 + + // CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]] + // CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false} + + // b and c are not loaded by ldmatrix in this test. + // CHECK-NOT: nvgpu.ldmatrix + + // CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]] + // CHECK: [[b_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3> + // CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32> + + // CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]]) + // CHECK-SAME: mmaShape = [16, 8, 4] + // CHECK-SAME: -> vector<2x2xf32> + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x4xf32> + %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x4xf32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32> + + // CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32> + // CHECK: affine.apply [[$rowC_map]] + // CHECK: affine.apply [[$colC_map]] + // CHECK: vector.store + // CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32> + // CHECK: affine.apply [[$rowC8_map]] + // CHECK: affine.apply [[$colC_map]] + // CHECK: vector.store + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32> + return +}