[MLIR][GPU] Add NvGpu mma.sync path to the VectorToGPU pass

This change adds the option to lower to NvGpu dialect ops during the
VectorToGPU conversion pass. Because this transformation reuses
existing VectorToGPU logic, a separate VectorToNvGpu conversion pass is
not created. The option `use-nvgpu` is added to the VectorToGPU pass.
When this is true, the pass will attempt to convert slices rooted at
`vector.contract` operations into `nvgpu.mma.sync` ops, and
`vector.transfer_read` ops are converted to either `nvgpu.ldmatrix` or
one or more `vector.load` operations.  The specific data loaded will
depend on the thread id within a subgroup (warp). These index
calculations depend on data type and shape of the MMA op
according to the downstream PTX specification. The code for supporting
these details is separated into `NvGpuSupport.cpp|h`.

Differential Revision: https://reviews.llvm.org/D122940
This commit is contained in:
Christopher Bate 2022-05-17 17:54:29 -06:00
parent 480dcdc897
commit 1ca772ed95
8 changed files with 1181 additions and 28 deletions

View File

@ -851,8 +851,13 @@ def ConvertVectorToGPU : Pass<"convert-vector-to-gpu"> {
"dialect"; "dialect";
let constructor = "mlir::createConvertVectorToGPUPass()"; let constructor = "mlir::createConvertVectorToGPUPass()";
let dependentDialects = [ let dependentDialects = [
"memref::MemRefDialect", "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect",
"gpu::GPUDialect" "vector::VectorDialect", "nvgpu::NVGPUDialect"
];
let options = [
Option<"useNvGpu", "use-nvgpu", "bool", /*default=*/"false",
"convert to NvGPU ops instead of GPU dialect ops">
]; ];
} }

View File

@ -17,16 +17,25 @@ class Pass;
class RewritePatternSet; class RewritePatternSet;
/// Patterns to transform vector ops into a canonical form to convert to MMA /// Patterns to transform vector ops into a canonical form to convert to MMA
/// matrix operations. /// matrix operations. If `useNvGpu` is true, then the populated patterns
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns); /// will prepare for conversion to `nvgpu` mma operations rather than the `gpu`
/// dialect WMMA operations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
bool useNvGpu = false);
/// Convert vector ops to MMA matrix operations nested under `rootOp`. This will /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
/// convert slice of operations that can be legally converted to MMA operations. /// convert slice of operations that can be legally converted to MMA operations.
/// The rest of the vector operations are left untouched. /// The rest of the vector operations are left untouched.
void convertVectorToMMAOps(Operation *rootOp); void convertVectorToMMAOps(Operation *rootOp);
/// Convert vector ops nested under `rootOp` to vector and GPU operations
/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
/// of operations that can be legally lowered on this path while the rest of
/// the vector operations are left untouched.
LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
/// Convert from vector to GPU ops. /// Convert from vector to GPU ops.
std::unique_ptr<Pass> createConvertVectorToGPUPass(); std::unique_ptr<Pass> createConvertVectorToGPUPass(bool useNvGpu = false);
} // namespace mlir } // namespace mlir

View File

@ -55,6 +55,10 @@ namespace LLVM {
class LLVMDialect; class LLVMDialect;
} // namespace LLVM } // namespace LLVM
namespace nvgpu {
class NVGPUDialect;
}
namespace NVVM { namespace NVVM {
class NVVMDialect; class NVVMDialect;
} // namespace NVVM } // namespace NVVM

View File

@ -1,5 +1,6 @@
add_mlir_conversion_library(MLIRVectorToGPU add_mlir_conversion_library(MLIRVectorToGPU
VectorToGPU.cpp VectorToGPU.cpp
NvGpuSupport.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU

View File

@ -0,0 +1,327 @@
//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//
#include "NvGpuSupport.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
namespace mlir {
namespace nvgpu {
namespace {
/// There are always 4 threads per [128|256|512] bit row.
constexpr int64_t kThreadsPerRow = 4;
constexpr int64_t kNumRowsPerTile = 8;
/// Tells whether `operandType` designates the accumulator/result (`C`) operand
/// of a matrix multiplication, as opposed to the `A` or `B` inputs.
bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  switch (operandType) {
  case MatMulOperandRole::C:
    return true;
  default:
    return false;
  }
}
/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
int64_t lineSize = inferTileWidthInBits(type);
auto shape = type.vectorType.getShape();
return (shape[0] / kNumRowsPerTile) *
(shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
lineSize;
}
/// Computes, as `{rows, cols}`, how many 8 x `lineSizeBits` tiles are needed to
/// cover an operand of shape `operandShape` with elements of `elementType`.
std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                    Type elementType, int64_t lineSizeBits) {
  // Within each 8x128bit square a thread owns a single 32bit register.
  const int64_t tileRows = operandShape[0] / kNumRowsPerTile;
  const int64_t rowBits =
      operandShape[1] * elementType.getIntOrFloatBitWidth();
  const int64_t tileCols = rowBits / lineSizeBits;
  return {tileRows, tileCols};
}
} // namespace
/// Gathers the fragment vector type and the matmul operand role for `op`.
/// Supported ops are `vector.transfer_write`, `vector.transfer_read`,
/// `vector.contract` and `arith.constant`; any other op emits an error and
/// yields failure.
FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
  // Reject anything outside the supported set up front.
  auto writeOp = dyn_cast<vector::TransferWriteOp>(op);
  if (!writeOp && !isa<vector::TransferReadOp, vector::ContractionOp,
                       arith::ConstantOp>(op)) {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  WarpMatrixInfo info;
  // A transfer write carries the vector as an operand; every other supported
  // op produces it as its first result.
  if (writeOp)
    info.vectorType = writeOp.getVectorType();
  else
    info.vectorType = op->getResult(0).getType().cast<VectorType>();

  // Default to accumulator/result. The role only becomes A or B when the
  // value feeds the lhs or rhs of a `vector.contract` directly; the first
  // matching user decides.
  info.operandRole = MatMulOperandRole::C;
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    Value produced = op->getResult(0);
    if (contract.getLhs() == produced) {
      info.operandRole = MatMulOperandRole::A;
      break;
    }
    if (contract.getRhs() == produced) {
      info.operandRole = MatMulOperandRole::B;
      break;
    }
  }
  return info;
}
int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
bool isAcc = isAccumulatorOrResult(type.operandRole);
Type elType = type.vectorType.getElementType();
if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
return 256;
}
if (elType.getIntOrFloatBitWidth() == 64) {
return isAcc ? 512 : 256;
}
return 128;
}
/// Describes the per-thread registers that make up the fragment `type`: the
/// register type, elements per register, register width in bits and register
/// count. Handles f16, f64, i8, i32 and f32 elements; any other element type
/// yields failure.
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elemTy = type.vectorType.getElementType();

  // Assembles the result. The register count is computed lazily here so that
  // it is only queried once a supported element type has been recognised.
  auto describe = [&type](Type registerTy, int64_t elementsPerRegister,
                          int64_t registerWidthBits) {
    return FragmentElementInfo{registerTy, elementsPerRegister,
                               registerWidthBits,
                               inferNumRegistersPerMatrixFragment(type)};
  };

  // f16: two values packed per 32-bit register.
  if (elemTy.isF16())
    return describe(LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32);

  // f64: accumulators use a pair of values, inputs a single scalar.
  if (elemTy.isF64()) {
    Type f64Ty = Float64Type::get(ctx);
    if (isAcc)
      return describe(LLVM::getFixedVectorType(f64Ty, 2), 2, 128);
    return describe(f64Ty, 1, 64);
  }

  // i8: four values packed per 32-bit register.
  if (elemTy.isInteger(8))
    return describe(LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4,
                    32);

  // i32 (accumulator operands): a pair of values.
  if (elemTy.isInteger(32))
    return describe(LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2,
                    64);

  // f32: accumulators use a pair of values, inputs a single scalar.
  if (elemTy.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    if (isAcc)
      return describe(LLVM::getFixedVectorType(f32Ty, 2), 2, 64);
    return describe(f32Ty, 1, 32);
  }

  return failure();
}
/// Builds a 2-dim affine map whose results are the (row, column) element
/// offsets of the tile that holds the register containing `logicalValueId`.
/// Registers are numbered first down the tile rows, then across tile columns.
/// NOTE(review): `isAccumulator` is not read by this function.
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const int64_t numTileRows =
      getTileShape(operandShape, elementType, lineSize)[0];

  // Which register of the fragment the logical value lives in.
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  AffineExpr rowOffset = (registerIdx % numTileRows) * kNumRowsPerTile;
  AffineExpr colOffset = registerIdx.floorDiv(numTileRows) * elementsPerLine;
  return AffineMap::get(2, 0, {rowOffset, colOffset},
                        elementType.getContext());
}
/// Builds the affine map `(laneId, logicalValueId) -> (row, col)` giving the
/// operand coordinate of each value a lane holds in its fragment. Fails when
/// no register description exists for the fragment's element type.
FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType) {
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  MLIRContext *ctx = builder.getContext();
  AffineExpr laneId, logicalValueIdDim;
  bindDims(ctx, laneId, logicalValueIdDim);

  // First locate the tile: the register that `logicalValueId` falls in is
  // used as a linear index into `index -> (tile row, tile col)`.
  AffineMap tileOffsets = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  // Then add the lane's position inside the tile: lanes are laid out
  // `kThreadsPerRow` per row, each covering `elementsPerRegister` columns.
  AffineExpr row = tileOffsets.getResult(0) + laneId.floorDiv(kThreadsPerRow);
  AffineExpr col = tileOffsets.getResult(1) +
                   (laneId % kThreadsPerRow) * elementsPerRegister +
                   (logicalValueIdDim % elementsPerRegister);
  return AffineMap::get(2, 0, {row, col}, ctx);
}
/// Computes the parameters needed to load the warp-level operand `type` with
/// `ldmatrix`.
///
/// - `targetLayout` is row-major for `A` and `C` operands, column-major for
///   `B`.
/// - `contiguousDimType` is `Parallel` when `transpose` is set and
///   `Reduction` otherwise.
/// - `numTiles` counts the 8-row x 128-bit tiles covering the operand,
///   measured along the target layout.
///
/// Returns failure when the operand is too small to cover a single tile.
FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  // Every field of the returned struct must be assigned: `isAccum` was
  // previously left indeterminate, so copying the struct read an
  // uninitialized bool.
  params.isAccum = isAccumulatorOrResult(type.operandRole);
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType =
      transpose ? IteratorType::Parallel : IteratorType::Reduction;

  // Tiles are 8 rows of 128 bits; which operand dimension supplies the rows
  // depends on the target layout.
  if (params.targetLayout == NVVM::MMALayout::row) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}
/// Builds the affine map `laneId -> (row, col)` giving the operand coordinate
/// whose address the lane should pass to `ldmatrix`. The mapping depends on
/// `params.contiguousDimType`; any value other than `Reduction` or `Parallel`
/// yields failure.
FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params) {
  // Each 128b row of a tile is addressed by one thread.
  const int64_t threadsPerTile = kNumRowsPerTile;
  const int elementBits = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int elementsPer128b = 128 / elementBits;
  ArrayRef<int64_t> shape = params.fragmentType.getShape();

  MLIRContext *ctx = builder.getContext();
  AffineExpr lane = getAffineDimExpr(0, ctx);

  // Row-major A|C operands or col-major B operands.
  if (params.contiguousDimType == IteratorType::Reduction) {
    AffineExpr row = lane % shape[0];
    AffineExpr col = lane.floorDiv(shape[0]) * elementsPer128b;
    return AffineMap::get(1, 0, {row, col}, ctx);
  }

  // Col-major A|C operands or row-major B operands. The shape handed in is
  // already pre-transposed (e.g. 8x16 = KxN).
  if (params.contiguousDimType == IteratorType::Parallel) {
    const int64_t tileColumns = (shape[0] * elementBits) / 128;
    // Lanes are handed out in groups of 8, first across tile columns and then
    // across tile rows. That is the transpose of what `ldmatrix` expects, but
    // with the `.trans` qualifier the net effect is to transpose only the
    // blocks.
    AffineExpr group = lane.floorDiv(threadsPerTile);
    AffineExpr tileCol = group % tileColumns;
    AffineExpr tileRow = group.floorDiv(tileColumns);
    AffineExpr first = tileCol * elementsPer128b;
    AffineExpr second = tileRow * kNumRowsPerTile + (lane % kNumRowsPerTile);
    return AffineMap::get(1, 0, {first, second}, ctx);
  }

  return failure();
}
/// Rewrites a `vector.contract` with (parallel, parallel, reduction) iterators
/// so that its indexing maps are in the canonical "TNT" form
/// `(m, k) x (n, k) -> (m, n)`, inserting `vector.transpose` ops on the
/// operands and/or swapping them as required.
///
/// Returns failure (no rewrite) when the op is already canonical, when it does
/// not have exactly three iterators of the expected kinds, or when its
/// indexing maps are not one of the recognized matmul layouts.
LogicalResult
PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();
  Value res = op.getAcc();

  // Set up the parallel/reduction structure in right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m;
  AffineExpr n;
  AffineExpr k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.getIteratorTypes().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
  if (iteratorTypes.size() != 3)
    return failure();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return failure();

  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
  if (maps == canonicalForm) {
    // Nothing to do; failing here keeps the greedy driver from looping.
    return failure();
  }

  // Result already (m, n): transpose whichever operand is not in canonical
  // layout.
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    // Only the lhs is transposed. This condition used to duplicate the one
    // below, which made the both-transposed case unreachable and applied a
    // single transpose to an op that needs two.
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  }
  // Result is (n, m): swap the operands, then transpose whichever of the
  // swapped operands is not in canonical layout.
  else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    std::swap(lhs, rhs);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    std::swap(lhs, rhs);
  } else {
    return failure();
  }

  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
      op.getIteratorTypes());
  return success();
}
} // namespace nvgpu
} // namespace mlir

View File

@ -0,0 +1,100 @@
//===- NvGpuSupport.h - MLIR Vector to GPU lowering support ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
namespace mlir {
namespace nvgpu {
/// Role a matrix plays in a matrix multiplication: `A` and `B` are the two
/// inputs, `C` is the accumulator/result.
enum class MatMulOperandRole : int32_t { A = 0, B, C };
/// Collects information about a warp-level matrix operand represented by a
/// VectorType.
struct WarpMatrixInfo {
/// Vector type of the whole warp-level operand.
VectorType vectorType;
/// Whether the operand is the `A`, `B` or `C` matrix of the multiplication.
MatMulOperandRole operandRole;
};
/// Given an op that operates on a VectorType representing a warp-level matrix
/// operand, the function returns a struct containing relevant type information.
FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
/// Returns the number of bits in a single tile row. It is either 128, 256, or
/// 512 bits depending on the data type and whether the operand is an
/// accumulator/result operand.
int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
/// Specifies information about the registers which compose a matrix fragment
/// according to the PTX documentation.
struct FragmentElementInfo {
/// Type of one register: either a scalar or a fixed vector of scalars.
Type registerLLVMType;
/// Number of matrix elements held by one register.
int64_t elementsPerRegister;
/// Width of one register in bits.
int64_t registerWidthBits;
/// Number of registers a single thread holds for the fragment.
int64_t numRegistersPerFragment;
};
/// Returns a FragmentElementInfo struct describing the register types for the
/// given matrix fragment type.
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type);
/// Returns an AffineMap which maps two dimensions representing (laneId,
/// logicalValueId) to two results representing offsets within a
/// matrix operand. The offsets point to the values the thread is responsible
/// for (AKA the matrix fragment values) during a warp-collective matrix
/// operation. For a visual reference of this LaneId -> (row, col) mapping,
/// please see NVIDIA's PTX documentation:
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
const WarpMatrixInfo &fragmentType);
/// Parameters describing how a warp-level matrix operand is loaded with
/// `ldmatrix`.
struct LdMatrixParams {
/// Vector type of the whole warp-level operand being loaded.
VectorType fragmentType;
/// NOTE(review): presumably whether the operand is an accumulator/result;
/// confirm against the producers and consumers of this struct.
bool isAccum;
/// Number of 8-row x 128-bit tiles that cover the operand.
int64_t numTiles;
/// `Parallel` for a transposed load, `Reduction` otherwise.
IteratorType contiguousDimType;
/// Row-major for `A` and `C` operands, column-major for `B` operands.
NVVM::MMALayout targetLayout;
};
/// Computes the `ldmatrix` parameters for the warp-level operand `type`;
/// `transpose` selects a transposed load. Fails when the operand does not
/// cover at least one tile.
FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
bool transpose);
/// Returns an AffineMap which maps a single dimension representing the laneId
/// to two results representing offsets within the matrix operand that should
/// be the pointer locations a thread should pass to the ldmatrix instruction.
FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
const LdMatrixParams &params);
/// Rewrite pattern that transforms a `vector.contract` into
/// (m, k)x(n, k)x(m, n) form so that it can be converted to MMA matmul.
struct PrepareContractToGPUMMASync
: public OpRewritePattern<vector::ContractionOp> {
// Inherit the base-class constructors.
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
/// Performs the rewrite on `op`; see the definition for the handled layouts.
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
};
} // namespace nvgpu
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H

View File

@ -12,6 +12,7 @@
#include <type_traits> #include <type_traits>
#include "NvGpuSupport.h"
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "../PassDetail.h" #include "../PassDetail.h"
@ -19,6 +20,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h"
@ -27,11 +29,39 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir; using namespace mlir;
/// Computes the indices of the transfer op `xferOp` shifted by the offsets in
/// `offsetMap` and appends them to `indices`, which is expected to be empty on
/// entry. Offsets are matched to indices through the op's permutation map:
/// each dimension result of the permutation map consumes the next result of
/// `offsetMap`. Values for the dimensions of `offsetMap` are supplied in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(OpBuilder &b, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned nextOffset = 0;
  for (auto permResult : xferOp.getPermutationMap().getResults()) {
    auto dimExpr = permResult.template dyn_cast<AffineDimExpr>();
    // Only dimension results carry an index to adjust.
    if (!dimExpr)
      continue;
    const unsigned position = dimExpr.getPosition();
    // Operands of the composed apply: the caller's dim values followed by the
    // original index, which therefore gets the first dim position past those
    // of `offsetMap`.
    SmallVector<Value, 3> applyOperands(dimValues.begin(), dimValues.end());
    applyOperands.push_back(indices[position]);
    AffineExpr originalIndex = b.getAffineDimExpr(offsetMap.getNumDims());
    indices[position] = makeComposedAffineApply(
        b, loc, originalIndex + offsetMap.getResult(nextOffset++),
        applyOperands);
  }
}
// Return true if the contract op can be convert to MMA matmul. // Return true if the contract op can be convert to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
bool useNvGpu) {
if (llvm::size(contract.getMasks()) != 0) if (llvm::size(contract.getMasks()) != 0)
return false; return false;
@ -47,7 +77,10 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
// The contract needs to represent a matmul to be able to convert to // The contract needs to represent a matmul to be able to convert to
// MMAMatrix matmul. // MMAMatrix matmul.
if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) if (!useNvGpu &&
contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
return false;
if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}}))
return false; return false;
return true; return true;
@ -75,7 +108,8 @@ getMemrefConstantHorizontalStride(ShapedType type) {
} }
// Return true if the transfer op can be converted to a MMA matrix load. // Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp,
bool useNvGpu) {
if (readOp.getMask() || readOp.hasOutOfBoundsDim() || if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
readOp.getVectorType().getRank() != 2) readOp.getVectorType().getRank() != 2)
return false; return false;
@ -87,9 +121,14 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
AffineExpr zero = b.getAffineConstantExpr(0); AffineExpr zero = b.getAffineConstantExpr(0);
auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
readOp.getContext()); readOp.getContext());
if (!useNvGpu) {
// TODO: Support transpose once it is added to GPU dialect ops. // TODO: Support transpose once it is added to GPU dialect ops.
// For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
return !(!map.isMinorIdentity() && map != broadcastInnerDim); return map.isMinorIdentity() || map == broadcastInnerDim;
}
return true;
} }
// Return true if the transfer op can be converted to a MMA matrix store. // Return true if the transfer op can be converted to a MMA matrix store.
@ -147,15 +186,15 @@ static bool elementwiseSupportsMMAMatrixType(Operation *op) {
return convertElementwiseOpToMMA(op).hasValue(); return convertElementwiseOpToMMA(op).hasValue();
} }
static bool supportsMMaMatrixType(Operation *op) { static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
if (isa<scf::ForOp, scf::YieldOp>(op)) if (isa<scf::ForOp, scf::YieldOp>(op))
return true; return true;
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
return transferReadSupportsMMAMatrixType(transferRead); return transferReadSupportsMMAMatrixType(transferRead, useNvGpu);
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
return transferWriteSupportsMMAMatrixType(transferWrite); return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto contract = dyn_cast<vector::ContractionOp>(op)) if (auto contract = dyn_cast<vector::ContractionOp>(op))
return contractSupportsMMAMatrixType(contract); return contractSupportsMMAMatrixType(contract, useNvGpu);
if (auto constant = dyn_cast<arith::ConstantOp>(op)) if (auto constant = dyn_cast<arith::ConstantOp>(op))
return constantSupportsMMAMatrixType(constant); return constantSupportsMMAMatrixType(constant);
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
@ -203,7 +242,8 @@ static SetVector<Operation *> getSliceContract(Operation *op,
// Analyze slice of operations based on convert op to figure out if the whole // Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations. // slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op) { static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
bool useNvGpu) {
auto hasVectorDest = [](Operation *op) { auto hasVectorDest = [](Operation *op) {
return llvm::any_of(op->getResultTypes(), return llvm::any_of(op->getResultTypes(),
[](Type t) { return t.isa<VectorType>(); }); [](Type t) { return t.isa<VectorType>(); });
@ -221,8 +261,9 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
// If any instruction cannot use MMA matrix type drop the whole // If any instruction cannot use MMA matrix type drop the whole
// chain. MMA matrix are stored in an opaque type so they cannot be used // chain. MMA matrix are stored in an opaque type so they cannot be used
// by all operations. // by all operations.
if (llvm::any_of(dependentOps, if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
[](Operation *op) { return !supportsMMaMatrixType(op); })) return !supportsMMaMatrixType(op, useNvGpu);
}))
return; return;
opToConvert.insert(dependentOps.begin(), dependentOps.end()); opToConvert.insert(dependentOps.begin(), dependentOps.end());
}); });
@ -351,7 +392,7 @@ static const char *inferFragType(OpTy op) {
static void convertTransferReadOp(vector::TransferReadOp op, static void convertTransferReadOp(vector::TransferReadOp op,
llvm::DenseMap<Value, Value> &valueMapping) { llvm::DenseMap<Value, Value> &valueMapping) {
assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
assert(transferReadSupportsMMAMatrixType(op)); assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false));
Optional<int64_t> stride = Optional<int64_t> stride =
getMemrefConstantHorizontalStride(op.getShapedType()); getMemrefConstantHorizontalStride(op.getShapedType());
AffineMap map = op.getPermutationMap(); AffineMap map = op.getPermutationMap();
@ -386,6 +427,250 @@ static void convertTransferWriteOp(vector::TransferWriteOp op,
op.erase(); op.erase();
} }
/// Builds the 2-D vector type that represents one thread's matrix fragment:
/// `numRegistersPerFragment x elementsPerRegister` of the scalar element type
/// of the register.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  // Strip a vector register type down to its scalar element type.
  Type scalarTy = regInfo.registerLLVMType;
  if (scalarTy.isa<VectorType>())
    scalarTy = scalarTy.cast<VectorType>().getElementType();

  SmallVector<int64_t> fragmentShape;
  fragmentShape.push_back(regInfo.numRegistersPerFragment);
  fragmentShape.push_back(regInfo.elementsPerRegister);
  return VectorType::get(fragmentShape, scalarTy);
}
/// Converts a splat `arith.constant` of warp-level matrix type into a splat
/// `arith.constant` of the per-thread fragment vector type and records the
/// new value in `valueMapping`. Fails when the fragment register type cannot
/// be determined or the constant is not a splat.
static LogicalResult
convertConstantOpMmaSync(arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder builder(op);

  FailureOr<nvgpu::WarpMatrixInfo> matrixInfo = nvgpu::getWarpMatrixInfo(op);
  if (failed(matrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> registerInfo =
      nvgpu::getMmaSyncRegisterType(*matrixInfo);
  if (failed(registerInfo))
    return failure();

  VectorType fragmentTy = getMmaSyncVectorOperandType(*registerInfo);

  // Only splat constants can be redistributed without knowing the lane.
  auto splat = op.getValue().dyn_cast<SplatElementsAttr>();
  if (!splat)
    return failure();

  auto fragmentAttr =
      DenseElementsAttr::get(fragmentTy, splat.getSplatValue<Attribute>());
  Value fragment =
      builder.create<arith::ConstantOp>(op.getLoc(), fragmentTy, fragmentAttr);
  valueMapping[op.getResult()] = fragment;
  return success();
}
/// Replaces the `vector.transfer_read` `op` by an `nvgpu.ldmatrix` whose
/// indices are the original indices offset by a lane-dependent amount, and
/// records the result in `valueMapping`. A non-minor-identity permutation map
/// is treated as a transposed load.
static LogicalResult
creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();
  const bool isTranspose = !op.getPermutationMap().isMinorIdentity();

  FailureOr<nvgpu::WarpMatrixInfo> matrixInfo = nvgpu::getWarpMatrixInfo(op);
  if (failed(matrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> registerInfo =
      nvgpu::getMmaSyncRegisterType(*matrixInfo);
  if (failed(registerInfo))
    return failure();

  FailureOr<nvgpu::LdMatrixParams> ldParams =
      nvgpu::getLdMatrixParams(*matrixInfo, /*transpose=*/isTranspose);
  if (failed(ldParams)) {
    return op->emitError()
           << "failed to convert vector.transfer_read to ldmatrix; this op "
              "likely "
              "should not be converted to a nvgpu.ldmatrix call.";
  }

  // Shift the load indices by the lane-dependent offsets.
  auto laneId = builder.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> laneOffsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(loc, builder, *ldParams);
  if (failed(laneOffsets))
    return failure();

  VectorType fragmentTy = getMmaSyncVectorOperandType(*registerInfo);

  SmallVector<Value, 4> loadIndices;
  getXferIndices<vector::TransferReadOp>(builder, op, *laneOffsets, {laneId},
                                         loadIndices);
  nvgpu::LdMatrixOp ldMatrix = builder.create<nvgpu::LdMatrixOp>(
      loc, fragmentTy, op.getSource(), loadIndices, isTranspose,
      ldParams->numTiles);
  valueMapping[op] = ldMatrix->getResult(0);
  return success();
}
/// Replaces the `vector.transfer_read` `op` by per-thread loads that assemble
/// the thread's matrix fragment without using `ldmatrix`, and records the
/// fragment in `valueMapping`.
///
/// Two cases are supported:
///  - non-transposed read of a row-layout operand: one `vector.load` per
///    register of the fragment;
///  - transposed read of a column-layout (`B`) operand: one `memref.load` per
///    element of the fragment.
/// Any other combination yields failure.
static LogicalResult
createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    op->emitError() << "Failed to deduce register fragment type during "
                       "conversion to distributed non-ldmatrix compatible load";
    return failure();
  }

  NVVM::MMALayout targetLayout =
      warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B
          ? NVVM::MMALayout::col
          : NVVM::MMALayout::row;

  Value laneId = builder.create<gpu::LaneIdOp>(loc);

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  // Start from a zero-filled fragment and insert the loaded values into it.
  Value fill = builder.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      builder.getZeroAttr(vectorType.getElementType()));
  Value result = builder.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // Vectorized loads.
  if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) {
    if (!loadedElType.isa<VectorType>()) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    // The (laneId, logicalValueId) -> coordinate map does not depend on the
    // loop index, so it is built once rather than per register.
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op.getLoc(), builder, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    for (int64_t i = 0; i < vectorType.getShape()[0]; i++) {
      // Logical id of the first value held by register `i`.
      Value logicalValueId = builder.create<arith::ConstantOp>(
          loc, builder.getIndexType(),
          builder.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          builder, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = builder.create<vector::LoadOp>(loc, loadedElType,
                                                op.getSource(), newIndices);
      result = builder.create<vector::InsertOp>(loc, el, result,
                                                builder.getI64ArrayAttr(i));
    }
  } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) {
    if (auto vecType = loadedElType.dyn_cast<VectorType>()) {
      loadedElType = vecType.getElementType();
    }

    // Loop-invariant for the same reason as in the vectorized case.
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        op.getLoc(), builder, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    // Load each element individually.
    for (int64_t i = 0; i < vectorType.getShape()[0]; i++) {
      for (int64_t innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {
        Value logicalValueId = builder.create<arith::ConstantOp>(
            loc, builder.getIndexType(),
            builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            builder, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = builder.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                  op.getSource(), newIndices);
        result = builder.create<vector::InsertOp>(
            op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx}));
      }
    }
  } else {
    return failure();
  }

  valueMapping[op.getResult()] = result;
  return success();
}
/// Lowers a `vector.transfer_read` to either an `nvgpu.ldmatrix` operation or
/// to a sequence of plain loads. Intended only for the conversion path that
/// targets `nvgpu.mma.sync`.
///
/// `nvgpu.ldmatrix` is chosen when the source memref is in memory space 3 and
/// the inferred tile width is 128 bits. A transposed read must additionally
/// have at least 8 entries in dimension 1 and at least 128 bits across
/// dimension 0. Otherwise the non-ldmatrix path is taken. Returns failure if
/// the warp matrix info cannot be computed or the chosen lowering fails.
static LogicalResult
convertTransferReadToLoads(vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder builder(op);

  FailureOr<nvgpu::WarpMatrixInfo> matrixInfo = nvgpu::getWarpMatrixInfo(op);
  if (failed(matrixInfo))
    return failure();

  auto sourceType = op.getSource().getType().cast<MemRefType>();
  bool useLdMatrix = sourceType.getMemorySpaceAsInt() == 3 &&
                     nvgpu::inferTileWidthInBits(*matrixInfo) == 128;

  VectorType readType = op.getVectorType();
  int64_t elementBits = readType.getElementType().getIntOrFloatBitWidth();

  // A transposed read can only use ldmatrix when there are at least 8 rows to
  // read and the width read for the transpose reaches 128 bits.
  bool isTransposed = !op.getPermutationMap().isMinorIdentity();
  if (isTransposed && !(readType.getDimSize(1) >= 8 &&
                        readType.getDimSize(0) * elementBits >= 128))
    useLdMatrix = false;

  if (useLdMatrix)
    return creatLdMatrixCompatibleLoads(op, builder, valueMapping);
  return createNonLdMatrixLoads(op, builder, valueMapping);
}
/// Lowers a `vector.transfer_write` whose written vector has already been
/// converted to a register fragment into one `vector.store` per entry along
/// the first dimension of the fragment type, then erases `op`.
///
/// Returns failure when no converted value is recorded in `valueMapping` for
/// the written vector, or when the warp matrix info, register fragment type,
/// or coordinate map cannot be computed. `op` is only erased on success.
static LogicalResult
convertTransferWriteToStores(vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);
  Location loc = op->getLoc();

  // The written vector must already have been converted; dereferencing a
  // missing entry would be undefined behaviour.
  auto mappedIt = valueMapping.find(op.getVector());
  if (mappedIt == valueMapping.end())
    return failure();
  Value matrix = mappedIt->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = b.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    // Logical id of the first element held by register `i`.
    Value logicalValueId = b.create<arith::ConstantOp>(
        loc, b.getIndexType(),
        b.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        loc, b, *warpMatrixInfo);
    if (failed(coords))
      return failure();

    Value el = b.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        b, op, *coords, {laneId, logicalValueId}, newIndices);
    b.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }

  op->erase();
  return success();
}
static void convertContractOp(vector::ContractionOp op, static void convertContractOp(vector::ContractionOp op,
llvm::DenseMap<Value, Value> &valueMapping) { llvm::DenseMap<Value, Value> &valueMapping) {
OpBuilder b(op); OpBuilder b(op);
@ -397,6 +682,22 @@ static void convertContractOp(vector::ContractionOp op,
valueMapping[op.getResult()] = matmul; valueMapping[op.getResult()] = matmul;
} }
/// Replaces the mapping of a `vector.contract` result with an
/// `nvgpu.mma.sync` built from the already-converted operands.
///
/// The MMA shape is read from the original operand types: `m` and `k` are
/// dimensions 0 and 1 of the LHS vector type and `n` is dimension 0 of the RHS
/// vector type. The new op's result type is the type of the converted
/// accumulator. Returns failure if any of the three operands has no entry in
/// `valueMapping`. The original contraction op is left in place.
static LogicalResult
convertContractOpToMmaSync(vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder b(op);

  // All operands must have been converted already; dereferencing a missing
  // entry would be undefined behaviour.
  auto lhsIt = valueMapping.find(op.getLhs());
  auto rhsIt = valueMapping.find(op.getRhs());
  auto accIt = valueMapping.find(op.getAcc());
  if (lhsIt == valueMapping.end() || rhsIt == valueMapping.end() ||
      accIt == valueMapping.end())
    return failure();

  // Copy the values out before `valueMapping` is modified below.
  Value opA = lhsIt->second;
  Value opB = rhsIt->second;
  Value opC = accIt->second;

  int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
  int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
  int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
  Value matmul = b.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(arith::ConstantOp op, static void convertConstantOp(arith::ConstantOp op,
llvm::DenseMap<Value, Value> &valueMapping) { llvm::DenseMap<Value, Value> &valueMapping) {
@ -509,13 +810,20 @@ static void convertElementwiseOp(Operation *op, gpu::MMAElementwiseOp opType,
valueMapping[op->getResult(0)] = newOp; valueMapping[op->getResult(0)] = newOp;
} }
void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) { void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
bool useNvGpu) {
if (!useNvGpu) {
patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>( patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
patterns.getContext()); patterns.getContext());
return;
}
patterns
.add<nvgpu::PrepareContractToGPUMMASync, CombineTransferReadOpTranspose>(
patterns.getContext());
} }
void mlir::convertVectorToMMAOps(Operation *rootOp) { void mlir::convertVectorToMMAOps(Operation *rootOp) {
SetVector<Operation *> ops = getOpToConvert(rootOp); SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
llvm::DenseMap<Value, Value> valueMapping; llvm::DenseMap<Value, Value> valueMapping;
for (Operation *op : ops) { for (Operation *op : ops) {
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
@ -538,21 +846,71 @@ void mlir::convertVectorToMMAOps(Operation *rootOp) {
} }
} }
/// Converts every operation selected by `getOpToConvert(rootOp, true)` to its
/// `nvgpu.mma.sync`-compatible form, dispatching on the operation type.
/// Converted values are tracked in a map local to this call. The first
/// operation that fails to convert, or that has no handler, gets an error
/// diagnostic and ends the conversion with failure.
LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) {
  SetVector<Operation *> worklist = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> fragments;

  for (Operation *current : worklist) {
    LogicalResult status =
        llvm::TypeSwitch<Operation *, LogicalResult>(current)
            .Case([&](vector::TransferReadOp readOp) {
              return convertTransferReadToLoads(readOp, fragments);
            })
            .Case([&](vector::TransferWriteOp writeOp) {
              return convertTransferWriteToStores(writeOp, fragments);
            })
            .Case([&](vector::ContractionOp contractOp) {
              return convertContractOpToMmaSync(contractOp, fragments);
            })
            .Case([&](scf::ForOp loopOp) {
              convertForOp(loopOp, fragments);
              return success();
            })
            .Case([&](scf::YieldOp terminator) {
              convertYieldOp(terminator, fragments);
              return success();
            })
            .Case([&](arith::ConstantOp constantOp) {
              return convertConstantOpMmaSync(constantOp, fragments);
            })
            .Default([&](Operation *unhandled) {
              unhandled->emitError()
                  << "unhandled vector to mma type: " << *unhandled;
              return failure();
            });

    if (succeeded(status))
      continue;

    current->emitError() << "Failed to convert op " << *current;
    return failure();
  }
  return success();
}
namespace {

/// Pass that first applies the preparation patterns and then converts vector
/// operations either to `nvgpu` operations (when the `useNvGpu` option is set)
/// or to GPU dialect MMA operations.
struct ConvertVectorToGPUPass
    : public ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  /// Initialises the `useNvGpu` pass option from `useNvGpu_`.
  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    if (useNvGpu.getValue()) {
      if (failed(convertVectorToNVVMCompatibleMMASync(getOperation())))
        signalPassFailure();
      // The nvgpu path replaces the GPU-dialect MMA conversion; do not fall
      // through and run both conversions over the same operations.
      return;
    }

    convertVectorToMMAOps(getOperation());
  }
};

} // namespace
std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass() { std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
return std::make_unique<ConvertVectorToGPUPass>(); return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
} }

View File

@ -0,0 +1,349 @@
// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})" | FileCheck %s
//#########################################################
// INT8 row-row-row
//#########################################################
// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)>
// CHECK-DAG: [[$rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)>
// CHECK-DAG: [[$colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
// CHECK-DAG: [[$rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 40)>
// CHECK-DAG: [[$rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)>
// CHECK-DAG: [[$rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)>
// CHECK-DAG: [[$rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)>
// CHECK-DAG: [[$rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)>
// CHECK-DAG: [[$rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)>
// CHECK-DAG: [[$rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)>
// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)>
// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)>
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @m16n8k32_int8_row_row_row
func.func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8, 3>, %arg2: memref<128x128xi32>) {
// Exercises the use-nvgpu lowering of an i8 x i8 -> i32 contraction whose
// expected nvgpu.mma.sync shape is [16, 8, 32]. Per the directives below,
// A (memory space 3, identity map) is expected to become one nvgpu.ldmatrix,
// B (memory space 3, transposing permutation map) eight scalar memref.load
// ops, and C (no memory space annotation) two vector.load ops. The result is
// expected to be written back with two vector.store ops.
// NOTE(review): %cst_0, %c0, %c17 and %c50 are not referenced in this function.
%cst_0 = arith.constant dense<0> : vector<32x8xi8>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c17 = arith.constant 17 : index
%c39 = arith.constant 39 : index
%c40 = arith.constant 40 : index
%c49 = arith.constant 49 : index
%c50 = arith.constant 50 : index
%cst = arith.constant 0 : i8
%cst0 = arith.constant 0 : i32
// Verify that the operand A is distributed to loads correctly.
// CHECK: [[row:%.+]] = affine.apply [[$rowA0_map]]()[{{%.+}}]
// CHECK: [[col:%.+]] = affine.apply [[$colA0_map]]()[{{%.+}}]
// CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8>
// Verify that the operand B is distributed to loads correctly. Its elements
// must be loaded in a non-vectorized manner to do the transpose.
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB0_map]]()[{{%.+}}]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB1_map]]()[{{%.+}}]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB2_map]]()[{{%.+}}]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB3_map]]()[{{%.+}}]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB4_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB5_map]]()[{{%.+}}]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB6_map]]()[{{%.+}}]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB7_map]]()[{{%.+}}]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB0_map]]()[{{%.+}}]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8, 3>
// CHECK-NOT: memref.load %arg1
// Verify that the operand C is distributed to loads correctly.
// CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
// CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
// CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
// CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
// CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
// CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
// CHECK-NOT: vector.load %arg2{{.*}}
%A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
%B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8, 3>, vector<8x32xi8>
%C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
// CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
// CHECK: [[row:%.+]] = affine.apply [[$rowC0_map]]()[{{%.+}}]
// CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
// CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
// CHECK: [[row:%.+]] = affine.apply [[$rowC8_map]]()[{{%.+}}]
// CHECK: [[col:%.+]] = affine.apply [[$colC0_map]]()[{{%.+}}]
// CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
return
}
// -----
//#########################################################
// f64 row-row-row
//#########################################################
// CHECK-DAG: [[$rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)>
// CHECK-DAG: [[$colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)>
// CHECK-DAG: [[$rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)>
// CHECK-DAG: [[$colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
// CHECK-DAG: [[$rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
// CHECK-DAG: [[$colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)>
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @m8n8k4_f64_row_row_row
func.func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) {
// Exercises the use-nvgpu lowering of an f64 contraction whose expected
// nvgpu.mma.sync shape is [8, 8, 4]. None of the memrefs carries a memory
// space annotation. Per the directives below, A is expected to become a
// vector.load of vector<1xf64>, B (transposing permutation map) a scalar
// memref.load, and C a vector.load of vector<2xf64>. The result is expected
// to be written back with a vector.store of vector<2xf64>.
// NOTE(review): %cst_0, %c0, %c17 and %c50 are not referenced in this function.
%cst_0 = arith.constant dense<0.0> : vector<4x8xf64>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c17 = arith.constant 17 : index
%c39 = arith.constant 39 : index
%c40 = arith.constant 40 : index
%c49 = arith.constant 49 : index
%c50 = arith.constant 50 : index
%cst = arith.constant 0.0 : f64
%cst0 = arith.constant 0.0 : f64
// Verify that the operand A is distributed to loads correctly.
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA0_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA0_map]]
// CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64>
// Verify that the operand B is distributed to loads correctly. Its elements
// must be loaded in a non-vectorized manner to do the transpose.
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowb0_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colb0_map]]
// CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
// CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
%A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64>
%B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64>
%C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64>
// CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowC0_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colC0_map]]
// CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64>
return
}
// -----
//#########################################################
// FP16 row-row-row
//#########################################################
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
// CHECK-DAG: [[$colB_map:#.+]] = affine_map<() -> (3)>
// CHECK-LABEL: func @m16n8k16_fp16_row_row_row
func.func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
// Exercises the use-nvgpu lowering of an f16 contraction with all operands in
// memory space 3. The directives below only pin the loads of A and B: A
// (identity map) is expected to become nvgpu.ldmatrix with numTiles = 4 and
// transpose = false, and B (transposing permutation map) nvgpu.ldmatrix with
// numTiles = 2 and transpose = true. The lowering of C, the contraction and
// the write are not matched here.
// NOTE(review): %cst_0 is not referenced in this function.
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%cst = arith.constant 0.000000e+00 : f16
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
// CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
// CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
%A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
%B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
%C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
return
}
// -----
// CHECK-DAG: [[$Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
// CHECK-DAG: [[$Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
// CHECK-DAG: [[$Bcol_map:#.+]] = affine_map<() -> (3)>
// CHECK-DAG: [[$Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
#map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @batch_m16n8k16_fp16_row_row_row
func.func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) {
// Same f16 contraction as the non-batched row-row-row case, but reading 2-D
// vectors out of rank-3 memrefs in memory space 3. The directives below expect
// each of A, B and C to become an nvgpu.ldmatrix whose leading index is the
// constant 0, followed by the per-lane row and column; only the load of B is
// expected to have transpose = true.
// NOTE(review): %cst_0 is not referenced in this function.
%cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
// CHECK: [[C0:%.+]] = arith.constant 0 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%cst = arith.constant 0.000000e+00 : f16
// CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
// CHECK: nvgpu.ldmatrix %arg0[[[C0]], [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
%A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$Brow_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$Bcol_map]]
// CHECK: nvgpu.ldmatrix %arg1[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
%B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$Arow_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$Acol_map]]
// CHECK: nvgpu.ldmatrix %arg2[[[C0]], [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
%C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
return
}
// -----
//#########################################################
// FP16 row-col-row
//#########################################################
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
// CHECK: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
// CHECK: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)>
// CHECK: [[$colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)>
// CHECK-LABEL: func @m16n8k16_fp16_row_col_row
func.func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
// Exercises the use-nvgpu lowering of an f16 contraction where none of the
// reads uses a permutation map. The directives below expect A, B and C to each
// become an nvgpu.ldmatrix with transpose = false; B is matched against its own
// row/column maps while A and C share the A maps.
// NOTE(review): %cst_0 is not referenced in this function. The result is
// written at [%c0, %c0] although C is read from [%c1, %c3]; the write is not
// matched by any directive here.
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%cst = arith.constant 0.000000e+00 : f16
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
// CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32
// CHECK-SAME: transpose = false
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
// CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32
// CHECK-SAME: transpose = false
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
// CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32
// CHECK-SAME: transpose = false
%A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
%B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
%C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
return
}
// -----
//#########################################################
// TF32 (multiplicand) F32 (accumulator) row-row-row
//#########################################################
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: [[$rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
// CHECK-DAG: [[$colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 4 + 3)>
// CHECK-DAG: [[$rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 3)>
// CHECK-DAG: [[$colB_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 3)>
// CHECK-DAG: [[$rowC_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4)>
// CHECK-DAG: [[$rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 8)>
// CHECK-DAG: [[$colC_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8)>
// CHECK-LABEL: func @m16n8k4_tf32_f32_row_row_row
func.func @m16n8k4_tf32_f32_row_row_row(%arg0: memref<20x20xf32, 3>, %arg1: memref<20x20xf32, 3>, %arg2: memref<20x20xf32>) {
// Exercises the use-nvgpu lowering of an f32 contraction whose expected
// nvgpu.mma.sync shape is [16, 8, 4] and whose accumulator is the constant
// %cst_0 rather than a loaded value. Per the directives below, the accumulator
// is expected to become a vector<2x2xf32> constant, A one nvgpu.ldmatrix with
// numTiles = 2, and B (transposing permutation map) a scalar memref.load
// inserted into a vector<1x1xf32>. The result is expected to be written back
// as two extract/store pairs.
%cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%cst = arith.constant 0.000000e+00 : f32
// CHECK: [[c_frag:%.+]] = arith.constant {{.*}} : vector<2x2xf32>
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowA_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colA_map]]
// CHECK: [[a_frag:%.+]] = nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 2 : i32, transpose = false}
// b and c are not loaded by ldmatrix in this test.
// CHECK-NOT: nvgpu.ldmatrix
// CHECK-DAG: [[row:%.+]] = affine.apply [[$rowB_map]]
// CHECK-DAG: [[col:%.+]] = affine.apply [[$colB_map]]
// CHECK: [[b_el:%.+]] = memref.load {{%.+}} : memref<20x20xf32, 3>
// CHECK: [[b_frag:%.+]] = vector.insert [[b_el]], {{.*}} : f32 into vector<1x1xf32>
// CHECK: [[d_frag:%.+]] = nvgpu.mma.sync([[a_frag]], [[b_frag]], [[c_frag]])
// CHECK-SAME: mmaShape = [16, 8, 4]
// CHECK-SAME: -> vector<2x2xf32>
%A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<16x4xf32>
%B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf32, 3>, vector<8x4xf32>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x4xf32>, vector<8x4xf32> into vector<16x8xf32>
// CHECK: vector.extract [[d_frag]][0] : vector<2x2xf32>
// CHECK: affine.apply [[$rowC_map]]
// CHECK: affine.apply [[$colC_map]]
// CHECK: vector.store
// CHECK: vector.extract [[d_frag]][1] : vector<2x2xf32>
// CHECK: affine.apply [[$rowC8_map]]
// CHECK: affine.apply [[$colC_map]]
// CHECK: vector.store
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<20x20xf32>
return
}