[mlir][VectorToGPU] Add conversion for splat constant to MMA const matrix

Differential Revision: https://reviews.llvm.org/D104133
This commit is contained in:
thomasraoux 2021-06-24 15:29:49 -07:00
parent 473a3a773e
commit 6413226dce
2 changed files with 51 additions and 2 deletions

View File

@ -113,6 +113,15 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
return true;
}
/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
auto vecType = constantOp.getType().dyn_cast<VectorType>();
if (!vecType || vecType.getRank() != 2)
return false;
return constantOp.value().isa<SplatElementsAttr>();
}
static bool supportsMMaMatrixType(Operation *op) {
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
return transferReadSupportsMMAMatrixType(transferRead);
@ -120,6 +129,8 @@ static bool supportsMMaMatrixType(Operation *op) {
return transferWriteSupportsMMAMatrixType(transferWrite);
if (auto contract = dyn_cast<vector::ContractionOp>(op))
return contractSupportsMMAMatrixType(contract);
if (auto constant = dyn_cast<ConstantOp>(op))
return constantSupportsMMAMatrixType(constant);
return false;
}
@ -241,10 +252,11 @@ struct CombineTransferReadOpTranspose final
} // namespace
// MMA types have different layout based on how they are used in matmul ops.
// Figure the right layout to use by looking at Transfer op uses.
// Figure the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at the this level and
// only care about it during lowering to NVVM.
static const char *inferFragType(vector::TransferReadOp op) {
template <typename OpTy>
static const char *inferFragType(OpTy op) {
for (Operation *users : op->getUsers()) {
auto contract = dyn_cast<vector::ContractionOp>(users);
if (!contract)
@ -297,6 +309,23 @@ static void convertContractOp(vector::ContractionOp op,
valueMapping[op.getResult()] = matmul;
}
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static void convertConstantOp(ConstantOp op,
llvm::DenseMap<Value, Value> &valueMapping) {
assert(constantSupportsMMAMatrixType(op));
OpBuilder b(op);
Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
auto scalarConstant =
b.create<ConstantOp>(op.getLoc(), splat.getType(), splat);
const char *fragType = inferFragType(op);
auto vecType = op.getType().cast<VectorType>();
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
scalarConstant);
valueMapping[op.getResult()] = matrix;
}
namespace mlir {
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
@ -314,6 +343,8 @@ void convertVectorToMMAOps(FuncOp funcOp) {
convertTransferWriteOp(transferWrite, valueMapping);
} else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
convertContractOp(contractOp, valueMapping);
} else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
convertConstantOp(constantOp, valueMapping);
}
}
}

View File

@ -23,6 +23,24 @@ func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<1
return
}
// CHECK-LABEL: func @matmul_cst
// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f16
// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[CST]] : !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
%cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
%c0 = constant 0 : index
%cst = constant 0.000000e+00 : f16
%A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
%D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
}
// Negative test until scf.for support is added.
// CHECK-LABEL: func @matmul_loop
// CHECK: vector.contract