[mlir][VectorToGPU] Add conversion for splat constant to MMA const matrix
Differential Revision: https://reviews.llvm.org/D104133
commit 6413226dce
parent 473a3a773e
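In short: a 2-D splat vector constant feeding an MMA-compatible contraction is now rewritten into a scalar constant wrapped in a gpu.subgroup_mma_constant_matrix op instead of blocking the conversion. A minimal before/after sketch, distilled from the test added below (SSA names are illustrative):

  // Before: splat constant used as the contraction accumulator.
  %acc = constant dense<0.000000e+00> : vector<16x16xf16>

  // After: scalar constant broadcast into an MMA "COp" fragment.
  %cst = constant 0.000000e+00 : f16
  %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">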
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -113,6 +113,15 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
   return true;
 }
 
+/// Return true if the constant is a splat to a 2D vector so that it can be
+/// converted to a MMA constant matrix op.
+static bool constantSupportsMMAMatrixType(ConstantOp constantOp) {
+  auto vecType = constantOp.getType().dyn_cast<VectorType>();
+  if (!vecType || vecType.getRank() != 2)
+    return false;
+  return constantOp.value().isa<SplatElementsAttr>();
+}
+
 static bool supportsMMaMatrixType(Operation *op) {
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
     return transferReadSupportsMMAMatrixType(transferRead);
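The new predicate only admits rank-2 splat constants; anything else stays on the plain vector path. A few hypothetical constants illustrating the check:

  %a = constant dense<1.000000e+00> : vector<16x16xf16>           // rank-2 splat: convertible
  %b = constant dense<1.000000e+00> : vector<16xf16>              // rank 1: rejected
  %c = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> // rank 2 but not a splat: rejected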
@@ -120,6 +129,8 @@ static bool supportsMMaMatrixType(Operation *op) {
     return transferWriteSupportsMMAMatrixType(transferWrite);
   if (auto contract = dyn_cast<vector::ContractionOp>(op))
     return contractSupportsMMAMatrixType(contract);
+  if (auto constant = dyn_cast<ConstantOp>(op))
+    return constantSupportsMMAMatrixType(constant);
   return false;
 }
 
@@ -241,10 +252,11 @@ struct CombineTransferReadOpTranspose final
 } // namespace
 
 // MMA types have different layout based on how they are used in matmul ops.
-// Figure the right layout to use by looking at Transfer op uses.
+// Figure the right layout to use by looking at op uses.
 // TODO: Change the GPU dialect to abstract the layout at this level and
 // only care about it during lowering to NVVM.
-static const char *inferFragType(vector::TransferReadOp op) {
+template <typename OpTy>
+static const char *inferFragType(OpTy op) {
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
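Making inferFragType a template lets the same use-based inference serve both transfer reads and constants: a value consumed as the contraction lhs maps to "AOp", the rhs to "BOp", and any other use (in practice the accumulator) to "COp". A sketch, with assumed standard matmul indexing maps and %A, %B, %C defined elsewhere:

  #mapA = affine_map<(m, n, k) -> (m, k)>
  #mapB = affine_map<(m, n, k) -> (k, n)>
  #mapC = affine_map<(m, n, k) -> (m, n)>
  // %A feeds the lhs -> "AOp"; %B the rhs -> "BOp"; the splat %C is the accumulator -> "COp".
  %D = vector.contract {indexing_maps = [#mapA, #mapB, #mapC],
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
         %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>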
@@ -297,6 +309,23 @@ static void convertContractOp(vector::ContractionOp op,
   valueMapping[op.getResult()] = matmul;
 }
 
+/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
+static void convertConstantOp(ConstantOp op,
+                              llvm::DenseMap<Value, Value> &valueMapping) {
+  assert(constantSupportsMMAMatrixType(op));
+  OpBuilder b(op);
+  Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
+  auto scalarConstant =
+      b.create<ConstantOp>(op.getLoc(), splat.getType(), splat);
+  const char *fragType = inferFragType(op);
+  auto vecType = op.getType().cast<VectorType>();
+  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
+      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
+  auto matrix = b.create<gpu::SubgroupMmaConstantMatrixOp>(op.getLoc(), type,
+                                                           scalarConstant);
+  valueMapping[op.getResult()] = matrix;
+}
+
 namespace mlir {
 
 void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) {
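convertConstantOp in sequence: getSplatValue() pulls the scalar attribute out of the dense splat, a scalar ConstantOp materializes it, inferFragType picks the fragment from the constant's contraction use, and MMAMatrixType keeps the original 2-D shape. Traced for a hypothetical 1.0 splat accumulator:

  // %v = constant dense<1.000000e+00> : vector<16x16xf16>   (input, used as accumulator)
  %s = constant 1.000000e+00 : f16   // scalar rebuilt from the splat value
  %m = gpu.subgroup_mma_constant_matrix %s : !gpu.mma_matrix<16x16xf16, "COp">   // shape + inferred "COp"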
@@ -314,6 +343,8 @@ void convertVectorToMMAOps(FuncOp funcOp) {
       convertTransferWriteOp(transferWrite, valueMapping);
     } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
       convertContractOp(contractOp, valueMapping);
+    } else if (auto constantOp = dyn_cast<ConstantOp>(op)) {
+      convertConstantOp(constantOp, valueMapping);
     }
   }
 }
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -23,6 +23,24 @@ func @matmul(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
   return
 }
 
+// CHECK-LABEL: func @matmul_cst
+// CHECK-DAG: %[[CST:.+]] = constant 0.000000e+00 : f16
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_constant_matrix %[[CST]] : !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %cst_0 = constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
+
 // Negative test until scf.for support is added.
 // CHECK-LABEL: func @matmul_loop
 // CHECK: vector.contract
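Pieced together from the CHECK lines above, the converted @matmul_cst should come out roughly as follows (value names illustrative):

  func @matmul_cst(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
    %c0 = constant 0 : index
    %cst = constant 0.000000e+00 : f16
    %A = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
    %B = gpu.subgroup_mma_load_matrix %arg1[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
    %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
    gpu.subgroup_mma_store_matrix %D, %arg2[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
    return
  }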