forked from OSchip/llvm-project
[mlir][gpu] Add basic support to do elementwise ops on mma matrix type
In order to support fusion with mma matrix type we need to be able to execute elementwise operations on them. This add an op to be able to support some basic elementwise operations. This is a is not a full solution as it only supports a limited scope or operations. Ideally we would want to be able to fuse with more kind of operations. Differential Revision: https://reviews.llvm.org/D112857
This commit is contained in:
parent
dfa0981407
commit
8a992b20db
|
@ -22,4 +22,9 @@ mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix GPU)
|
|||
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU)
|
||||
add_public_tablegen_target(MLIRGPUPassIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS GPUOps.td)
|
||||
mlir_tablegen(GPUOpsEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(GPUOpsEnums.cpp.inc -gen-enum-defs)
|
||||
add_public_tablegen_target(MLIRGPUOpsEnumsGen)
|
||||
|
||||
add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc)
|
||||
|
|
|
@ -115,18 +115,4 @@ def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
|
|||
];
|
||||
}
|
||||
|
||||
// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
|
||||
// the layouts of the operands supported by the ops that use this attribute.
|
||||
def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
|
||||
def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
|
||||
|
||||
// Specifies a String enum Attribute for Warp wide matrix operations,
|
||||
// representing the layout of respective operands. The layout later governs
|
||||
// the lowerings to appropriate intrinsics.
|
||||
def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
|
||||
[RowMajor, ColMajor]> {
|
||||
let stringToSymbolFnName = "LayoutStrToEnum";
|
||||
let symbolToStringFnName = "EnumToLayoutStr";
|
||||
}
|
||||
|
||||
#endif // GPU_BASE
|
||||
|
|
|
@ -166,6 +166,8 @@ void addAsyncDependency(Operation *op, Value token);
|
|||
} // end namespace gpu
|
||||
} // end namespace mlir
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUOpsEnums.h.inc"
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUOpInterfaces.h.inc"
|
||||
|
|
|
@ -591,13 +591,13 @@ def GPU_YieldOp : GPU_Op<"yield", [NoSideEffect, Terminator]>,
|
|||
}
|
||||
|
||||
// add, mul mirror the XLA ComparisonDirection enum.
|
||||
def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
|
||||
def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
|
||||
def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
|
||||
def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
|
||||
def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
|
||||
def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
|
||||
def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
|
||||
def GPU_AllReduceOpAdd : StrEnumAttrCase<"ADD", -1, "add">;
|
||||
def GPU_AllReduceOpAnd : StrEnumAttrCase<"AND", -1, "and">;
|
||||
def GPU_AllReduceOpMax : StrEnumAttrCase<"MAX", -1, "max">;
|
||||
def GPU_AllReduceOpMin : StrEnumAttrCase<"MIN", -1, "min">;
|
||||
def GPU_AllReduceOpMul : StrEnumAttrCase<"MUL", -1, "mul">;
|
||||
def GPU_AllReduceOpOr : StrEnumAttrCase<"OR", -1, "or">;
|
||||
def GPU_AllReduceOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
|
||||
|
||||
def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
|
||||
"built-in reduction operations supported by gpu.allreduce.",
|
||||
|
@ -644,7 +644,7 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
|
|||
let verifier = [{ return ::verifyAllReduce(*this); }];
|
||||
}
|
||||
|
||||
def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
|
||||
def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
|
||||
|
||||
def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
|
||||
"Indexing modes supported by gpu.shuffle.",
|
||||
|
@ -1121,4 +1121,60 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
|
|||
}];
|
||||
}
|
||||
|
||||
def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">;
|
||||
def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">;
|
||||
def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
|
||||
def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
|
||||
|
||||
def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
|
||||
"elementwise operation to apply to mma matrix",
|
||||
[GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL,
|
||||
GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
|
||||
let cppNamespace = "::mlir::gpu";
|
||||
let storageType = "::mlir::StringAttr";
|
||||
let returnType = "::mlir::gpu::MMAElementwiseOp";
|
||||
let convertFromStorage = "*symbolizeMMAElementwiseOp($_self.getValue())";
|
||||
let constBuilderCall = "$_builder.getStringAttr(stringifyEnum($0))";
|
||||
}
|
||||
|
||||
def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
|
||||
[NoSideEffect,
|
||||
AllTypesMatch<["args"]>]>{
|
||||
|
||||
let summary = "GPU warp elementwise operation on a matrix";
|
||||
|
||||
let description = [{
|
||||
The `gpu.subgroup_mma_elementwise` takes `!gpu.mma_matrix` inputs and
|
||||
compute a new `!gpu.mma_matrix` by applying an elementwise operation to each
|
||||
element.
|
||||
|
||||
Since the operation is elementwise and the matrix type must match, the
|
||||
matrix elements are processed independently of the matrix layout.
|
||||
|
||||
This op is meant to be used along with `gpu.subgroup_mma_compute`.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = %A, %B { operation = "ADD" } :
|
||||
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
|
||||
-> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<GPU_MMAMatrix>:$args, MMAElementWiseAttr:$operation);
|
||||
|
||||
let results = (outs GPU_MMAMatrix:$res);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
gpu::MMAMatrixType getType() {
|
||||
return res().getType().cast<gpu::MMAMatrixType>();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$args attr-dict `:` functional-type($args, $res)
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // GPU_OPS
|
||||
|
|
|
@ -1187,11 +1187,11 @@ class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
|
|||
}
|
||||
|
||||
// An enum attribute case stored with StringAttr.
|
||||
class StrEnumAttrCase<string sym, int val = -1> :
|
||||
EnumAttrCaseInfo<sym, val, sym>,
|
||||
class StrEnumAttrCase<string sym, int val = -1, string str = sym> :
|
||||
EnumAttrCaseInfo<sym, val, str>,
|
||||
StringBasedAttr<
|
||||
CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">,
|
||||
"case " # sym>;
|
||||
CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
|
||||
"case " # str>;
|
||||
|
||||
// An enum attribute case stored with IntegerAttr, which has an integer value,
|
||||
// its representation as a string and a C++ symbol name which may be different.
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -352,13 +353,90 @@ struct WmmaConstantOpToNVVMLowering
|
|||
}
|
||||
};
|
||||
|
||||
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
|
||||
Value rhs, bool isMin) {
|
||||
auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
|
||||
Type i1Type = builder.getI1Type();
|
||||
if (auto vecType = lhs.getType().dyn_cast<VectorType>())
|
||||
i1Type = VectorType::get(vecType.getShape(), i1Type);
|
||||
Value cmp = builder.create<LLVM::FCmpOp>(
|
||||
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
|
||||
lhs, rhs);
|
||||
Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
|
||||
Value isNan = builder.create<LLVM::FCmpOp>(
|
||||
loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
|
||||
Value nan = builder.create<LLVM::ConstantOp>(
|
||||
loc, lhs.getType(),
|
||||
builder.getFloatAttr(floatType,
|
||||
APFloat::getQNaN(floatType.getFloatSemantics())));
|
||||
return builder.create<LLVM::SelectOp>(loc, isNan, sel, nan);
|
||||
}
|
||||
|
||||
static Value createScalarOp(OpBuilder &builder, Location loc,
|
||||
gpu::MMAElementwiseOp op,
|
||||
ArrayRef<Value> operands) {
|
||||
switch (op) {
|
||||
case gpu::MMAElementwiseOp::ADDF:
|
||||
return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
|
||||
case gpu::MMAElementwiseOp::MULF:
|
||||
return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
|
||||
case gpu::MMAElementwiseOp::MAXF:
|
||||
return createMinMaxF(builder, loc, operands[0], operands[1],
|
||||
/*isMin=*/false);
|
||||
case gpu::MMAElementwiseOp::MINF:
|
||||
return createMinMaxF(builder, loc, operands[0], operands[1],
|
||||
/*isMin=*/true);
|
||||
}
|
||||
llvm_unreachable("unknown op");
|
||||
}
|
||||
|
||||
/// Convert GPU MMA elementwise ops to extract + op + insert.
|
||||
struct WmmaElementwiseOpToNVVMLowering
|
||||
: public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
|
||||
adaptor.getOperands(), rewriter)))
|
||||
return failure();
|
||||
Location loc = subgroupMmaElementwiseOp.getLoc();
|
||||
size_t numOperands = adaptor.getOperands().size();
|
||||
LLVM::LLVMStructType destType = convertMMAToLLVMType(
|
||||
subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
|
||||
Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
|
||||
for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
|
||||
SmallVector<Value> extractedOperands;
|
||||
for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
|
||||
Type elementType = adaptor.getOperands()[opIdx]
|
||||
.getType()
|
||||
.cast<LLVM::LLVMStructType>()
|
||||
.getBody()[i];
|
||||
extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, elementType, adaptor.getOperands()[opIdx],
|
||||
rewriter.getI32ArrayAttr(i)));
|
||||
}
|
||||
Value element =
|
||||
createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.operation(),
|
||||
extractedOperands);
|
||||
matrixStruct = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, matrixStruct, element, rewriter.getI32ArrayAttr(i));
|
||||
}
|
||||
rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace mlir {
|
||||
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
|
||||
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering>(
|
||||
converter);
|
||||
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
|
||||
WmmaElementwiseOpToNVVMLowering>(converter);
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -24,6 +24,7 @@ add_mlir_dialect_library(MLIRGPUOps
|
|||
|
||||
DEPENDS
|
||||
MLIRGPUOpsIncGen
|
||||
MLIRGPUOpsEnumsGen
|
||||
MLIRGPUOpInterfacesIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
|
|
|
@ -1185,6 +1185,7 @@ void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|||
}
|
||||
|
||||
#include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
|
||||
#include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
|
||||
|
|
|
@ -220,3 +220,33 @@ gpu.module @test_module {
|
|||
return %C : !gpu.mma_matrix<16x16xf16, "COp">
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
gpu.module @test_module {
|
||||
|
||||
// CHECK-LABEL: func @gpu_wmma_elementwise
|
||||
// CHECK: %[[M0:.*]] = llvm.mlir.undef : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B0:.*]] = llvm.extractvalue %{{.*}}[0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[C0:.*]] = llvm.fadd %[[A0]], %[[B0]] : vector<2xf16>
|
||||
// CHECK: %[[M1:.*]] = llvm.insertvalue %[[C0]], %[[M0]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B1:.*]] = llvm.extractvalue %{{.*}}[1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[C1:.*]] = llvm.fadd %[[A1]], %[[B1]] : vector<2xf16>
|
||||
// CHECK: %[[M2:.*]] = llvm.insertvalue %[[C1]], %[[M1]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B2:.*]] = llvm.extractvalue %{{.*}}[2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[C2:.*]] = llvm.fadd %[[A2]], %[[B2]] : vector<2xf16>
|
||||
// CHECK: %[[M3:.*]] = llvm.insertvalue %[[C2]], %[[M2]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[A3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[B3:.*]] = llvm.extractvalue %{{.*}}[3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: %[[C3:.*]] = llvm.fadd %[[A3]], %[[B3]] : vector<2xf16>
|
||||
// CHECK: %[[M4:.*]] = llvm.insertvalue %[[C3]], %[[M3]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
// CHECK: llvm.return %[[M4]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
|
||||
builtin.func @gpu_wmma_elementwise(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) ->(!gpu.mma_matrix<16x16xf16, "COp">) {
|
||||
%C = gpu.subgroup_mma_elementwise %A, %B { operation = "ADDF" } :
|
||||
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
|
||||
return %C : !gpu.mma_matrix<16x16xf16, "COp">
|
||||
}
|
||||
}
|
||||
|
|
|
@ -220,7 +220,10 @@ module attributes {gpu.container_module} {
|
|||
%0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
// CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
|
||||
%1 = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf32, "COp">
|
||||
// CHECK: gpu.subgroup_mma_constant_matrix %[[cst]] : !gpu.mma_matrix<16x16xf32, "COp">
|
||||
// CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
|
||||
%2 = gpu.subgroup_mma_elementwise %1, %1 {operation = "ADDF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
|
||||
// CHECK: gpu.subgroup_mma_elementwise %{{.*}}, %{{.*}} {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
|
||||
%3 = gpu.subgroup_mma_elementwise %2, %1 {operation = "MAXF"} : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2768,6 +2768,14 @@ gentbl_cc_library(
|
|||
["-gen-op-defs"],
|
||||
"include/mlir/Dialect/GPU/GPUOps.cpp.inc",
|
||||
),
|
||||
(
|
||||
["-gen-enum-decls"],
|
||||
"include/mlir/Dialect/GPU/GPUOpsEnums.h.inc",
|
||||
),
|
||||
(
|
||||
["-gen-enum-defs"],
|
||||
"include/mlir/Dialect/GPU/GPUOpsEnums.cpp.inc",
|
||||
),
|
||||
],
|
||||
tblgen = ":mlir-tblgen",
|
||||
td_file = "include/mlir/Dialect/GPU/GPUOps.td",
|
||||
|
|
Loading…
Reference in New Issue