From 8a992b20dba54a061717a14eab86ccbe097da4c0 Mon Sep 17 00:00:00 2001 From: thomasraoux Date: Mon, 1 Nov 2021 11:43:54 -0700 Subject: [PATCH] [mlir][gpu] Add basic support to do elementwise ops on mma matrix type In order to support fusion with mma matrix type we need to be able to execute elementwise operations on them. This add an op to be able to support some basic elementwise operations. This is a is not a full solution as it only supports a limited scope or operations. Ideally we would want to be able to fuse with more kind of operations. Differential Revision: https://reviews.llvm.org/D112857 --- mlir/include/mlir/Dialect/GPU/CMakeLists.txt | 5 ++ mlir/include/mlir/Dialect/GPU/GPUBase.td | 14 ---- mlir/include/mlir/Dialect/GPU/GPUDialect.h | 2 + mlir/include/mlir/Dialect/GPU/GPUOps.td | 72 ++++++++++++++-- mlir/include/mlir/IR/OpBase.td | 8 +- .../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 82 ++++++++++++++++++- mlir/lib/Dialect/GPU/CMakeLists.txt | 1 + mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 1 + .../GPUToNVVM/wmma-ops-to-nvvm.mlir | 30 +++++++ mlir/test/Dialect/GPU/ops.mlir | 5 +- .../llvm-project-overlay/mlir/BUILD.bazel | 8 ++ 11 files changed, 199 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt index 73aa1d92ffc1..4808ec53e4e7 100644 --- a/mlir/include/mlir/Dialect/GPU/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/GPU/CMakeLists.txt @@ -22,4 +22,9 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GPU) mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU) add_public_tablegen_target(MLIRGPUPassIncGen) +set(LLVM_TARGET_DEFINITIONS GPUOps.td) +mlir_tablegen(GPUOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(GPUOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRGPUOpsEnumsGen) + add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td index a7bd8ece6a1c..6c2fa43679d2 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td @@ -115,18 +115,4 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> { ]; } -// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing -// the layouts of the operands supported by the ops that use this attribute. -def RowMajor: StrEnumAttrCase<"RowMajor", 0>; -def ColMajor: StrEnumAttrCase<"ColMajor", 1>; - -// Specifies a String enum Attribute for Warp wide matrix operations, -// representing the layout of respective operands. The layout later governs -// the lowerings to appropriate intrinsics. -def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major", - [RowMajor, ColMajor]> { - let stringToSymbolFnName = "LayoutStrToEnum"; - let symbolToStringFnName = "EnumToLayoutStr"; -} - #endif // GPU_BASE diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h index 79e8dca5af9c..5c1b9db33c56 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -166,6 +166,8 @@ void addAsyncDependency(Operation *op, Value token); } // end namespace gpu } // end namespace mlir +#include "mlir/Dialect/GPU/GPUOpsEnums.h.inc" + #include "mlir/Dialect/GPU/GPUOpsDialect.h.inc" #include "mlir/Dialect/GPU/GPUOpInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td index b92d315b19ff..18b5adfd2445 100644 --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -591,13 +591,13 @@ def GPU_YieldOp : GPU_Op<"yield", [NoSideEffect, Terminator]>, } // add, mul mirror the XLA ComparisonDirection enum. -def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">; -def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">; -def GPU_AllReduceOpMax : StrEnumAttrCase<"max">; -def GPU_AllReduceOpMin : StrEnumAttrCase<"min">; -def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">; -def GPU_AllReduceOpOr : StrEnumAttrCase<"or">; -def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">; +def GPU_AllReduceOpAdd : StrEnumAttrCase<"ADD", -1, "add">; +def GPU_AllReduceOpAnd : StrEnumAttrCase<"AND", -1, "and">; +def GPU_AllReduceOpMax : StrEnumAttrCase<"MAX", -1, "max">; +def GPU_AllReduceOpMin : StrEnumAttrCase<"MIN", -1, "min">; +def GPU_AllReduceOpMul : StrEnumAttrCase<"MUL", -1, "mul">; +def GPU_AllReduceOpOr : StrEnumAttrCase<"OR", -1, "or">; +def GPU_AllReduceOpXor : StrEnumAttrCase<"XOR", -1, "xor">; def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr", "built-in reduction operations supported by gpu.allreduce.", @@ -644,7 +644,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce", let verifier = [{ return ::verifyAllReduce(*this); }]; } -def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">; +def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">; def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr", "Indexing modes supported by gpu.shuffle.", @@ -1121,4 +1121,60 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix", }]; } +def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">; +def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">; +def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">; +def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">; + +def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp", + "elementwise operation to apply to mma matrix", + [GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL, + GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> { + let cppNamespace = "::mlir::gpu"; + let storageType = "::mlir::StringAttr"; + let returnType = "::mlir::gpu::MMAElementwiseOp"; + let convertFromStorage = "*symbolizeMMAElementwiseOp($_self.getValue())"; + let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))"; +} + +def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise", + [NoSideEffect, + AllTypesMatch<["args"]>]>{ + + let summary = "GPU warp elementwise operation on a matrix"; + + let description = [{ + The `gpu.subgroup_mma_elementwise` takes `!gpu.mma_matrix` inputs and + compute a new `!gpu.mma_matrix` by applying an elementwise operation to each + element. + + Since the operation is elementwise and the matrix type must match, the + matrix elements are processed independently of the matrix layout. + + This op is meant to be used along with `gpu.subgroup_mma_compute`. + + Example: + + ```mlir + %0 = %A, %B { operation = "ADD" } : + (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) + -> !gpu.mma_matrix<16x16xf16, "COp"> + ``` + }]; + + let arguments = (ins Variadic:$args, MMAElementWiseAttr:$operation); + + let results = (outs GPU_MMAMatrix:$res); + + let extraClassDeclaration = [{ + gpu::MMAMatrixType getType() { + return res().getType().cast(); + } + }]; + + let assemblyFormat = [{ + $args attr-dict `:` functional-type($args, $res) + }]; +} + #endif // GPU_OPS diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index ec0d5355dcf1..e63a2672bf31 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1187,11 +1187,11 @@ class EnumAttrCaseInfo { } // An enum attribute case stored with StringAttr. -class StrEnumAttrCase : - EnumAttrCaseInfo, +class StrEnumAttrCase : + EnumAttrCaseInfo, StringBasedAttr< - CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">, - "case " # sym>; + CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">, + "case " # str>; // An enum attribute case stored with IntegerAttr, which has an integer value, // its representation as a string and a C++ symbol name which may be different. diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 49d48bf2d630..878d0cf22fd8 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/TypeUtilities.h" using namespace mlir; @@ -352,13 +353,90 @@ struct WmmaConstantOpToNVVMLowering } }; +static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs, + Value rhs, bool isMin) { + auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + Type i1Type = builder.getI1Type(); + if (auto vecType = lhs.getType().dyn_cast()) + i1Type = VectorType::get(vecType.getShape(), i1Type); + Value cmp = builder.create( + loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt, + lhs, rhs); + Value sel = builder.create(loc, cmp, lhs, rhs); + Value isNan = builder.create( + loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs); + Value nan = builder.create( + loc, lhs.getType(), + builder.getFloatAttr(floatType, + APFloat::getQNaN(floatType.getFloatSemantics()))); + return builder.create(loc, isNan, sel, nan); +} + +static Value createScalarOp(OpBuilder &builder, Location loc, + gpu::MMAElementwiseOp op, + ArrayRef operands) { + switch (op) { + case gpu::MMAElementwiseOp::ADDF: + return builder.create(loc, operands[0].getType(), operands); + case gpu::MMAElementwiseOp::MULF: + return builder.create(loc, operands[0].getType(), operands); + case gpu::MMAElementwiseOp::MAXF: + return createMinMaxF(builder, loc, operands[0], operands[1], + /*isMin=*/false); + case gpu::MMAElementwiseOp::MINF: + return createMinMaxF(builder, loc, operands[0], operands[1], + /*isMin=*/true); + } + llvm_unreachable("unknown op"); +} + +/// Convert GPU MMA elementwise ops to extract + op + insert. +struct WmmaElementwiseOpToNVVMLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(), + adaptor.getOperands(), rewriter))) + return failure(); + Location loc = subgroupMmaElementwiseOp.getLoc(); + size_t numOperands = adaptor.getOperands().size(); + LLVM::LLVMStructType destType = convertMMAToLLVMType( + subgroupMmaElementwiseOp.getType().cast()); + Value matrixStruct = rewriter.create(loc, destType); + for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { + SmallVector extractedOperands; + for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { + Type elementType = adaptor.getOperands()[opIdx] + .getType() + .cast() + .getBody()[i]; + extractedOperands.push_back(rewriter.create( + loc, elementType, adaptor.getOperands()[opIdx], + rewriter.getI32ArrayAttr(i))); + } + Value element = + createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(), + extractedOperands); + matrixStruct = rewriter.create( + loc, matrixStruct, element, rewriter.getI32ArrayAttr(i)); + } + rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct); + return success(); + } +}; + } // anonymous namespace namespace mlir { void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.insert( - converter); + WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering, + WmmaElementwiseOpToNVVMLowering>(converter); } } // namespace mlir diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt index 2beb7ea7bc88..14520ce6767d 100644 --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRGPUOps DEPENDS MLIRGPUOpsIncGen + MLIRGPUOpsEnumsGen MLIRGPUOpInterfacesIncGen LINK_LIBS PUBLIC diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index ba1710b57a91..9baff7f53ca8 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1185,6 +1185,7 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results, } #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc" +#include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/GPU/GPUOps.cpp.inc" diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir index 4c035acaf738..c0ac8a050288 100644 --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -220,3 +220,33 @@ gpu.module @test_module { return %C : !gpu.mma_matrix<16x16xf16, "COp"> } } + +// ----- + +gpu.module @test_module { + +// CHECK-LABEL: func @gpu_wmma_elementwise +// CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]] : vector<2xf16> +// CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[C1:.*]] = llvm.fadd %[[A1]], %[[B1]] : vector<2xf16> +// CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[C2:.*]] = llvm.fadd %[[A2]], %[[B2]] : vector<2xf16> +// CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: %[[C3:.*]] = llvm.fadd %[[A3]], %[[B3]] : vector<2xf16> +// CHECK: %[[M4:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> +// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + builtin.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) ->(!gpu.mma_matrix<16x16xf16, "COp">) { + %C = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } : + (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp"> + return %C : !gpu.mma_matrix<16x16xf16, "COp"> + } +} diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 297fb5fe6fe2..c24fd7bf8a81 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -220,7 +220,10 @@ module attributes {gpu.container_module} { %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> %1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp"> - // CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp"> + // CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp"> + %2 = gpu.subgroup_mma_elementwise %1, %1 {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp"> + // CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp"> + %3 = gpu.subgroup_mma_elementwise %2, %1 {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp"> return } } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index cd1e34d1964d..a5a59eb9cd63 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2768,6 +2768,14 @@ gentbl_cc_library( ["-gen-op-defs"], "include/mlir/Dialect/GPU/GPUOps.cpp.inc", ), + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/GPU/GPUOpsEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/GPU/GPUOpsEnums.cpp.inc", + ), ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/GPU/GPUOps.td",