[mlir][gpu] Add basic support to do elementwise ops on mma matrix type

In order to support fusion with the mma matrix type, we need to be able to
execute elementwise operations on it. This adds an op to support some basic
elementwise operations. This is not a full solution, as it only supports a
limited set of operations. Ideally we would want to be able to fuse with more
kinds of operations.
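
As a rough sketch of the intended use (not taken from this change's tests;
`%A`, `%B`, `%C`, `%out`, and `%c0` are assumed to be defined elsewhere), the
new op lets an elementwise epilogue stay on the `!gpu.mma_matrix` type between
a compute and a store:

```mlir
// Warp-level matrix multiply-accumulate.
%D = gpu.subgroup_mma_compute %A, %B, %C :
       !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
       -> !gpu.mma_matrix<16x16xf16, "COp">
// Fused elementwise epilogue applied directly to the accumulator fragment.
%E = gpu.subgroup_mma_elementwise %D, %C { operation = "MAXF" } :
       (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
       -> !gpu.mma_matrix<16x16xf16, "COp">
// Write the fused result back to memory.
gpu.subgroup_mma_store_matrix %E, %out[%c0, %c0] {leadDimension = 32 : index} :
       !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
```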

Differential Revision: https://reviews.llvm.org/D112857
thomasraoux 2021-11-01 11:43:54 -07:00
parent dfa0981407
commit 8a992b20db
11 changed files with 199 additions and 29 deletions


@ -22,4 +22,9 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GPU)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU)
add_public_tablegen_target(MLIRGPUPassIncGen)
set(LLVM_TARGET_DEFINITIONS GPUOps.td)
mlir_tablegen(GPUOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(GPUOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRGPUOpsEnumsGen)
add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc)


@ -115,18 +115,4 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
];
}
// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
// the layouts of the operands supported by the ops that use this attribute.
def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
// Specifies a String enum Attribute for Warp wide matrix operations,
// representing the layout of respective operands. The layout later governs
// the lowerings to appropriate intrinsics.
def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
[RowMajor, ColMajor]> {
let stringToSymbolFnName = "LayoutStrToEnum";
let symbolToStringFnName = "EnumToLayoutStr";
}
#endif // GPU_BASE


@ -166,6 +166,8 @@ void addAsyncDependency(Operation *op, Value token);
} // end namespace gpu
} // end namespace mlir
#include "mlir/Dialect/GPU/GPUOpsEnums.h.inc"
#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
#include "mlir/Dialect/GPU/GPUOpInterfaces.h.inc"


@ -591,13 +591,13 @@ def GPU_YieldOp : GPU_Op<"yield", [NoSideEffect, Terminator]>,
}
// add, mul mirror the XLA ComparisonDirection enum.
def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
def GPU_AllReduceOpAdd : StrEnumAttrCase<"ADD", -1, "add">;
def GPU_AllReduceOpAnd : StrEnumAttrCase<"AND", -1, "and">;
def GPU_AllReduceOpMax : StrEnumAttrCase<"MAX", -1, "max">;
def GPU_AllReduceOpMin : StrEnumAttrCase<"MIN", -1, "min">;
def GPU_AllReduceOpMul : StrEnumAttrCase<"MUL", -1, "mul">;
def GPU_AllReduceOpOr : StrEnumAttrCase<"OR", -1, "or">;
def GPU_AllReduceOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
"built-in reduction operations supported by gpu.allreduce.",
@ -644,7 +644,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
let verifier = [{ return ::verifyAllReduce(*this); }];
}
def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
"Indexing modes supported by gpu.shuffle.",
@ -1121,4 +1121,60 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
}];
}
def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">;
def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">;
def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
"elementwise operation to apply to mma matrix",
[GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL,
GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
let cppNamespace = "::mlir::gpu";
let storageType = "::mlir::StringAttr";
let returnType = "::mlir::gpu::MMAElementwiseOp";
let convertFromStorage = "*symbolizeMMAElementwiseOp($_self.getValue())";
let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
}
def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
[NoSideEffect,
AllTypesMatch<["args"]>]>{
let summary = "GPU warp elementwise operation on a matrix";
let description = [{
The `gpu.subgroup_mma_elementwise` op takes `!gpu.mma_matrix` inputs and
computes a new `!gpu.mma_matrix` by applying an elementwise operation to each
element.
Since the operation is elementwise and the matrix type must match, the
matrix elements are processed independently of the matrix layout.
This op is meant to be used along with `gpu.subgroup_mma_compute`.
Example:
```mlir
%0 = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
-> !gpu.mma_matrix<16x16xf16, "COp">
```
}];
let arguments = (ins Variadic<GPU_MMAMatrix>:$args, MMAElementWiseAttr:$operation);
let results = (outs GPU_MMAMatrix:$res);
let extraClassDeclaration = [{
gpu::MMAMatrixType getType() {
return res().getType().cast<gpu::MMAMatrixType>();
}
}];
let assemblyFormat = [{
$args attr-dict `:` functional-type($args, $res)
}];
}
#endif // GPU_OPS


@ -1187,11 +1187,11 @@ class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
}
// An enum attribute case stored with StringAttr.
class StrEnumAttrCase<string sym, int val = -1> :
EnumAttrCaseInfo<sym, val, sym>,
class StrEnumAttrCase<string sym, int val = -1, string str = sym> :
EnumAttrCaseInfo<sym, val, str>,
StringBasedAttr<
CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">,
"case " # sym>;
CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
"case " # str>;
// An enum attribute case stored with IntegerAttr, which has an integer value,
// its representation as a string and a C++ symbol name which may be different.


@ -15,6 +15,7 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
@ -352,13 +353,90 @@ struct WmmaConstantOpToNVVMLowering
}
};
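/// Emits a NaN-propagating floating-point min/max: selects the smaller/larger
/// operand with an ordered comparison, then substitutes a quiet NaN whenever
/// either operand is unordered (i.e. NaN).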
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
Value rhs, bool isMin) {
auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
Type i1Type = builder.getI1Type();
if (auto vecType = lhs.getType().dyn_cast<VectorType>())
i1Type = VectorType::get(vecType.getShape(), i1Type);
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
Value isNan = builder.create<LLVM::FCmpOp>(
loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
Value nan = builder.create<LLVM::ConstantOp>(
loc, lhs.getType(),
builder.getFloatAttr(floatType,
APFloat::getQNaN(floatType.getFloatSemantics())));
return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}
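/// Lowers a single MMA elementwise opcode to the corresponding LLVM dialect
/// operation on element values extracted from the MMA matrix struct.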
static Value createScalarOp(OpBuilder &builder, Location loc,
gpu::MMAElementwiseOp op,
ArrayRef<Value> operands) {
switch (op) {
case gpu::MMAElementwiseOp::ADDF:
return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MULF:
return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
case gpu::MMAElementwiseOp::MAXF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/false);
case gpu::MMAElementwiseOp::MINF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/true);
}
llvm_unreachable("unknown op");
}
/// Convert GPU MMA elementwise ops to extract + op + insert.
struct WmmaElementwiseOpToNVVMLowering
: public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
using ConvertOpToLLVMPattern<
gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
adaptor.getOperands(), rewriter)))
return failure();
Location loc = subgroupMmaElementwiseOp.getLoc();
size_t numOperands = adaptor.getOperands().size();
LLVM::LLVMStructType destType = convertMMAToLLVMType(
subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
SmallVector<Value> extractedOperands;
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
Type elementType = adaptor.getOperands()[opIdx]
.getType()
.cast<LLVM::LLVMStructType>()
.getBody()[i];
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
loc, elementType, adaptor.getOperands()[opIdx],
rewriter.getI32ArrayAttr(i)));
}
Value element =
createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(),
extractedOperands);
matrixStruct = rewriter.create<LLVM::InsertValueOp>(
loc, matrixStruct, element, rewriter.getI32ArrayAttr(i));
}
rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
return success();
}
};
} // anonymous namespace
namespace mlir {
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>(
converter);
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
WmmaElementwiseOpToNVVMLowering>(converter);
}
} // namespace mlir


@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRGPUOps
DEPENDS
MLIRGPUOpsIncGen
MLIRGPUOpsEnumsGen
MLIRGPUOpInterfacesIncGen
LINK_LIBS PUBLIC


@ -1185,6 +1185,7 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
#include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"


@ -220,3 +220,33 @@ gpu.module @test_module {
return %C : !gpu.mma_matrix<16x16xf16, "COp">
}
}
// -----
gpu.module @test_module {
// CHECK-LABEL: func @gpu_wmma_elementwise
// CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]] : vector<2xf16>
// CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C1:.*]] = llvm.fadd %[[A1]], %[[B1]] : vector<2xf16>
// CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C2:.*]] = llvm.fadd %[[A2]], %[[B2]] : vector<2xf16>
// CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: %[[C3:.*]] = llvm.fadd %[[A3]], %[[B3]] : vector<2xf16>
// CHECK: %[[M4:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
builtin.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) ->(!gpu.mma_matrix<16x16xf16, "COp">) {
%C = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } :
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
return %C : !gpu.mma_matrix<16x16xf16, "COp">
}
}
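
The test above only exercises `ADDF`. For reference, a `MAXF` elementwise op
would expand each extracted `vector<2xf16>` element into a NaN-propagating
compare/select sequence roughly like the sketch below (illustrative, not
FileCheck-verified output; `%a` and `%b` stand for the two extracted operands):

```mlir
%cmp   = llvm.fcmp "ogt" %a, %b : vector<2xf16>
%sel   = llvm.select %cmp, %a, %b : vector<2xi1>, vector<2xf16>
%isnan = llvm.fcmp "uno" %a, %b : vector<2xf16>
// Quiet-NaN bit pattern for f16, splatted across the vector.
%nan   = llvm.mlir.constant(dense<0x7E00> : vector<2xf16>) : vector<2xf16>
%max   = llvm.select %isnan, %nan, %sel : vector<2xi1>, vector<2xf16>
```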


@ -220,7 +220,10 @@ module attributes {gpu.container_module} {
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
// CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
%1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
// CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp">
// CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
%2 = gpu.subgroup_mma_elementwise %1, %1 {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
// CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
%3 = gpu.subgroup_mma_elementwise %2, %1 {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
return
}
}


@ -2768,6 +2768,14 @@ gentbl_cc_library(
["-gen-op-defs"],
"include/mlir/Dialect/GPU/GPUOps.cpp.inc",
),
(
["-gen-enum-decls"],
"include/mlir/Dialect/GPU/GPUOpsEnums.h.inc",
),
(
["-gen-enum-defs"],
"include/mlir/Dialect/GPU/GPUOpsEnums.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/GPU/GPUOps.td",